From 91284f07e9a2ae35a00186d39248923aac26aa42 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 12 Aug 2024 08:28:30 +0000 Subject: [PATCH 01/13] broadcast attention_bias dim 0 and 1 --- js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 42 +- .../jsep/webgpu/ops/multihead-attention.ts | 121 +-- .../test/data/ops/multihead-attention.jsonc | 36 +- .../contrib_ops/cpu/bert/attention_base.cc | 54 +- .../contrib_ops/cpu/bert/attention_base.h | 4 +- .../contrib_ops/cpu/bert/attention_common.h | 6 +- .../contrib_ops/cpu/bert/attention_cpu_base.h | 123 +-- .../cpu/bert/multihead_attention.cc | 10 +- .../cpu/bert/multihead_attention_helper.h | 79 +- .../contrib_ops/cpu/utils/console_dumper.h | 38 + .../contrib_ops/cpu/utils/dump_tensor.cc | 29 + .../contrib_ops/cpu/utils/dump_tensor.h | 5 + .../contrib_ops/cuda/bert/attention.cc | 19 +- .../contrib_ops/cuda/bert/attention_impl.cu | 50 +- .../contrib_ops/cuda/bert/attention_impl.h | 5 +- .../cuda/bert/attention_prepare_qkv.cu | 12 +- .../cuda/bert/attention_softmax.cu | 715 ++++++++++-------- .../contrib_ops/cuda/bert/attention_softmax.h | 13 +- .../bert/cutlass_fmha/fmha_launch_template.h | 22 +- .../cutlass_fmha/memory_efficient_attention.h | 10 +- .../cuda/bert/decoder_attention_impl.cu | 14 +- .../cuda/bert/group_query_attention_impl.cu | 1 - .../cuda/bert/multihead_attention.cc | 21 +- .../contrib_ops/cuda/bert/packed_attention.cc | 60 +- .../cuda/bert/packed_attention_impl.cu | 13 +- .../cuda/bert/packed_attention_impl.h | 2 +- .../cuda/bert/packed_multihead_attention.cc | 65 +- .../bert/packed_multihead_attention_impl.cu | 12 +- .../bert/packed_multihead_attention_impl.h | 3 +- .../cuda/utils/dump_cuda_tensor.cc | 42 + .../contrib_ops/cuda/utils/dump_cuda_tensor.h | 14 +- .../contrib_ops/rocm/bert/attention.cu | 8 +- ...ed_gemm_softmax_gemm_permute_pipelines.cuh | 3 +- .../rocm/bert/multihead_attention.cu | 12 +- .../core/graph/contrib_ops/bert_defs.cc | 5 +- .../External/DirectMLHelpers/DirectMLSchema.h | 4 +- 
.../DirectMLHelpers/GeneratedSchemaHelpers.h | 4 +- .../src/Operators/DmlOperatorAttention.cpp | 34 +- .../DmlOperatorMultiHeadAttention.cpp | 24 +- .../src/Operators/DmlOperatorQAttention.cpp | 6 +- .../python/tools/transformers/constants.py | 4 +- .../transformers/convert_to_packing_mode.py | 20 +- .../contrib_ops/attention_op_test_helper.cc | 138 ++-- .../contrib_ops/attention_op_test_helper.h | 20 +- .../multihead_attention_op_test.cc | 54 +- .../multihead_attention_op_test_data_gen.py | 6 +- .../packed_multihead_attention_op_test.cc | 76 +- .../test/python/transformers/benchmark_mha.py | 82 +- .../test/python/transformers/test_mha.py | 235 ++++-- .../attention/attention_test_data.txt | 76 +- .../packed_multihead_attention_test_data.txt | 34 +- 51 files changed, 1435 insertions(+), 1050 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index 435267a1b9652..31a8823447f0d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -92,7 +92,7 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte // bias (Q/K/V) : (D + D + D_v) // mask_index : see below // past (K/V) : (2, B, N, P, H) or NULL - // relative_position_bias : (B, N, S, T) or NULL + // attention_bias : (B, N, S, T) or NULL // For mask_index, the following shapes are supported: // NULL, (B, 1), (1, 1) @@ -109,10 +109,10 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte const bias = inputs[2]; const maskIndex = inputs[3]; const past = inputs[4]; - const relativePositionBias = inputs[5]; + const attentionBias = inputs[5]; - if (past && relativePositionBias) { - throw new Error('Attention cannot have both past and relative_position_bias'); + if (past && attentionBias) { + throw new Error('Attention cannot have both past and attention_bias'); } if (input.dims.length !== 3) { @@ -208,6 +208,20 @@ const validateAttentionInputs = 
(inputs: readonly TensorView[], attributes: Atte throw new Error('past is not supported'); } + if (attentionBias) { + if (attentionBias.dims.length !== 4) { + throw new Error('Input "attention_bias" must have 4 dimensions'); + } + + // TODO: support broadcasting the first and second dimensions of attention_bias + if (attentionBias.dims[0] !== batchSize || + attentionBias.dims[1] !== attributes.numHeads || + attentionBias.dims[2] !== sequenceLength || + attentionBias.dims[3] !== totalSequenceLength) { + throw new Error('Input "attention_bias" shape shall be (batch_size, num_heads, sequence_length, total_sequence_length)'); + } + } + return { batchSize, sequenceLength, @@ -336,7 +350,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor const createAttentionProbsProgramInfo = (context: ComputeContext, q: TensorView, key: TensorView, pastKey: TensorView|undefined, - relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs, + attentionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs, pastSequenceLength: number) => { const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength; const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength]; @@ -367,7 +381,7 @@ const createAttentionProbsProgramInfo = if (pastKey) { inputDependencies.push('type'); } - if (relativePositionBias) { + if (attentionBias) { inputDependencies.push('type'); } const outputs = [{dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default}]; @@ -382,9 +396,9 @@ const createAttentionProbsProgramInfo = const pastKeyInput = inputVariable('past_key', pastKey.dataType, pastKey.dims, components); inputVars.push(pastKeyInput); } - if (relativePositionBias) { + if (attentionBias) { inputVars.push( - inputVariable('relative_position_bias', relativePositionBias.dataType, relativePositionBias.dims)); + 
inputVariable('attention_bias', attentionBias.dataType, attentionBias.dims)); } const output = outputVariable('output', q.dataType, probsShape); const outputVars = [output]; @@ -473,14 +487,14 @@ const createAttentionProbsProgramInfo = } })()}; output[outputIdx] = ${output.type.value} (sum * uniforms.alpha) + ${ - relativePositionBias ? 'relative_position_bias[outputIdx]' : '0.0'}; + attentionBias ? 'attention_bias[outputIdx]' : '0.0'}; } }`; }; return { name: 'AttentionProbs', shaderCache: { - hint: `${components};${relativePositionBias !== undefined};${pastKey !== undefined};${context.outputCount}`, + hint: `${components};${attentionBias !== undefined};${pastKey !== undefined};${context.outputCount}`, inputDependencies }, getRunData: () => ({outputs, dispatchGroup: dispatch, programUniforms}), @@ -614,21 +628,21 @@ const createVxAttentionScoreProgramInfo = export const applyAttention = (context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, _maskIndex: TensorView|undefined, _past: TensorView|undefined, pastKey: TensorView|undefined, pastValue: TensorView|undefined, - relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => { + attentionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => { const outputCount = context.outputCount; const pastSequenceLength = parameters.kvNumHeads !== undefined || outputCount > 1 ? parameters.pastSequenceLength : 0; const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength; const inputsK = (parameters.kvNumHeads === undefined && outputCount > 1 && pastKey) ? [q, k, pastKey] : [q, k]; - if (relativePositionBias) { - inputsK.push(relativePositionBias); + if (attentionBias) { + inputsK.push(attentionBias); } // Run AttentionProbs const probs = context.compute( createAttentionProbsProgramInfo( - context, q, k, outputCount > 1 ? 
pastKey : undefined, relativePositionBias, parameters, attributes, + context, q, k, outputCount > 1 ? pastKey : undefined, attentionBias, parameters, attributes, pastSequenceLength), {inputs: inputsK, outputs: (parameters.kvNumHeads === undefined && outputCount > 1) ? [-1, 1] : [-1]})[0]; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts index 09fadea66fa1f..c83bf1481e109 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts @@ -20,53 +20,60 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr const value = getInput(inputs, 2); const bias = getInput(inputs, 3); const keyPaddingMask = getInput(inputs, 4); - const relativePositionBias = getInput(inputs, 5); + const attentionBias = getInput(inputs, 5); const pastKey = getInput(inputs, 6); const pastValue = getInput(inputs, 7); - // Abbreviation and Meanings: - // B: batch_size - // S: sequence_length (input sequence length of query) - // P: past_sequence_length (past sequence length of key or value) - // L: kv_sequence_length (input sequence length of key or value) - // M: max_sequence_length - // T: total_sequence_length = past_sequence_length + kv_sequence_length - // N: num_heads - // H: head size for Q and K, aka q_head_size or k_head_size or qk_head_size - // H_v: v_head_size - // D_i: input hidden size - // D: hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size - // D_v: v_hidden_size = num_heads * v_head_size - - // key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None - // relative_position_bias : (B, 1, S, L) - // past_key : (B, N, S*, H) - // past_value : (B, N, S*, H) - // When no packing for q/k/v: + // --------------------------------------------------------------- + // Notations: + // B: batch_size + // N: num_heads + // H: head_size of Q and K + // H_v: head_size of V + // D: 
hidden_size for Q and K, where D = N * H + // D_v: hidden_size of V, where D_v = N * H_v + // S: q_sequence_length + // P: past_sequence_length of kv cache + // L: kv_sequence_length + // T: total_sequence_length = P + L + // M: max_sequence_length of kv cache when past and present share buffer + // --------------------------------------------------------------- + // MultiHeadAttention inputs: + // --------------------------------------------------------------- + // Q_K_V_BSNH - no packing: // query (Q) : (B, S, D) - // key (K) : (B, L, D) or (B, N, S*, H) - // value (V) : (B, L, D_v) or (B, N, S*, H) - // bias (Q/K/V) : (D + D + D_v) - // When packed kv is used: + // key (K) : (B, L, D) + // value (V) : (B, L, D_v) + // Q_K_V_BSNH_BNSH_BNSH - cross attention (kv cache is not used, L == T, D == D_v): // query (Q) : (B, S, D) - // key (K) : (B, L, N, 2, H) - // value (V) : None - // bias (Q/K/V) : None - // When packed qkv is used: - // query (Q) : (B, L, N, 3, H) or (B, S, 3*D) - // key (K) : None - // value (V) : None + // key (K) : (B, N, L, H) + // value (V) : (B, N, L, H_v) + // Q_KV_BSNH_BSN2H - packed kv (kv cache is not used, bias is not allowed for packed kv): + // query (Q) : (B, S, D) + // key (K/V) : (B, L, N, 2, H) + // value : None + // QKV_BSN3H - packed qkv (kv cache is not used, S == L, D == D_v): + // query (Q/K/V) : (B, S, N, 3, H) + // key : None + // value : None + // + // Other inputs: // bias (Q/K/V) : None or (D + D + D_v) + // key_padding_mask (K/V) : (B) or (3 * B + 2) or (B, T) or (B, S, T) + // attention_bias : None or (B, N, S, T), (1, N, S, T), (B, 1, S, T) or (1, 1, S, T) + // past_key : (B, N, P, H) or None. Past state is only allowed for Q_K_V_BSNH. + // past_value : (B, N, P, H) or None. Past state is only allowed for Q_K_V_BSNH. + // + // Not Supported: + // key_padding_mask, packed kv, packed qkv, and broadcast for attention_bias. 
if (query.dims.length !== 3 && query.dims.length !== 5) { throw new Error('Input query is expected to have 3 or 5 dimensions'); } - const dmmhaPacking = false; const batchSize = query.dims[0]; const sequenceLength = query.dims[1]; - const hiddenSize = query.dims.length === 3 ? (dmmhaPacking ? query.dims[2] / 3 : query.dims[2]) : - attributes.numHeads * query.dims[4]; + const hiddenSize = query.dims.length === 3 ? query.dims[2] : (attributes.numHeads * query.dims[4]); let kvSequenceLength = sequenceLength; let pastSequenceLength = 0; @@ -127,14 +134,14 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key'); } - qkvFormat = AttentionQkvFormat.unknown; + qkvFormat = AttentionQkvFormat.unknown; // Q_K_V_BSNH_BNSH_BNSH kvSequenceLength = key.dims[2]; } } else { // packed QKV - if (query.dims.length !== 3 && query.dims.length !== 5) { - throw new Error('Input "query" is expected to have 3 or 5 dimensions when key is empty'); + if (query.dims.length !== 5) { + throw new Error('Input "query" is expected to have 5 dimensions when key is empty'); } - if (query.dims.length === 5 && (query.dims[2] !== attributes.numHeads || query.dims[3] !== 3)) { + if (query.dims[2] !== attributes.numHeads || query.dims[3] !== 3) { throw new Error('Expect "query" shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv'); } @@ -146,13 +153,15 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr throw new Error('Input "bias" is expected to have 1 dimension'); } - if (value) { - if (query.dims.length === 5 && query.dims[3] === 2) { + if (key) { + if (key.dims.length === 5 && key.dims[3] === 2) { throw new Error('bias is not allowed for packed kv.'); } } } + const totalSequenceLength = pastSequenceLength + kvSequenceLength; + let maskType: AttentionMaskType = AttentionMaskType.none; if (keyPaddingMask) { 
maskType = AttentionMaskType.maskUnknown; @@ -163,11 +172,11 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr } else if (maskDims[0] === 3 * batchSize + 2) { maskType = AttentionMaskType.mask1DKeySeqLenStart; } - } else if (maskDims.length === 2 && maskDims[0] === batchSize && maskDims[1] === kvSequenceLength) { + } else if (maskDims.length === 2 && maskDims[0] === batchSize && maskDims[1] === totalSequenceLength) { maskType = AttentionMaskType.mask2dKeyPadding; } if (maskType === AttentionMaskType.maskUnknown) { - throw new Error('Input "key_padding_mask" shape shall be (batch_size) or (batch_size, kv_sequence_length)'); + throw new Error('Input "key_padding_mask" shape shall be (batch_size) or (batch_size, total_sequence_length)'); } throw new Error('Mask not supported'); } @@ -188,30 +197,32 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr throw new Error('Input "key" and "value" shall have the same dim 1 (kv_sequence_length)'); } vHiddenSize = value.dims[2]; - } else { + } else { // Q_K_V_BSNH_BNSH_BNSH if (kvSequenceLength !== value.dims[2]) { - throw new Error('Input "past_key" and "past_value" shall have the same dim 2 (kv_sequence_length)'); + throw new Error('Input "key" and "value" shall have the same dim 2 (kv_sequence_length)'); } vHiddenSize = value.dims[1] * value.dims[3]; passPastInKv = true; } } - const totalSequenceLength = pastSequenceLength + kvSequenceLength; const broadcastResPosBias = false; if (keyPaddingMask) { throw new Error('Key padding mask is not supported'); } - if (relativePositionBias) { - if (relativePositionBias.dims.length !== 4) { - throw new Error('Input "relative_position_bias" is expected to have 4 dimensions'); + if (attentionBias) { + if (attentionBias.dims.length !== 4) { + throw new Error('Input "attention_bias" is expected to have 4 dimensions'); } - if ((relativePositionBias.dims[0] !== batchSize && relativePositionBias.dims[0] !== 1) || - 
relativePositionBias.dims[1] !== attributes.numHeads || relativePositionBias.dims[2] !== sequenceLength || - relativePositionBias.dims[3] !== totalSequenceLength) { - throw new Error('Input "relative_position_bias" shape (batch_size, 1, sequence_length, kv_sequence_length)'); + + // TODO: support broadcasting the first and second dimensions of attention_bias. + if (attentionBias.dims[0] !== batchSize || + attentionBias.dims[1] !== attributes.numHeads || + attentionBias.dims[2] !== sequenceLength || + attentionBias.dims[3] !== totalSequenceLength) { + throw new Error('Input "attention_bias" shape shall be (batch_size, num_heads, sequence_length, total_sequence_length)'); } } @@ -320,7 +331,7 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio const value = getInput(context.inputs, 2); const bias = getInput(context.inputs, 3); const keyPaddingMask = getInput(context.inputs, 4); - const relativePositionBias = getInput(context.inputs, 5); + const attentionBias = getInput(context.inputs, 5); const pastKey = getInput(context.inputs, 6); const pastValue = getInput(context.inputs, 7); if (query.dims.length === 5) { @@ -339,7 +350,7 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio if (kvBNSH) { return applyAttention( - context, Q, key, value, keyPaddingMask, undefined, pastKey, pastValue, relativePositionBias, params, + context, Q, key, value, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params, attributes); } if (!key || !value) { @@ -354,5 +365,5 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio 2 * params.hiddenSize); applyAttention( - context, Q, K, V, keyPaddingMask, undefined, pastKey, pastValue, relativePositionBias, params, attributes); + context, Q, K, V, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params, attributes); }; diff --git a/js/web/test/data/ops/multihead-attention.jsonc b/js/web/test/data/ops/multihead-attention.jsonc 
index 6ce6a5e0a8ce6..ed937a22c0b84 100644 --- a/js/web/test/data/ops/multihead-attention.jsonc +++ b/js/web/test/data/ops/multihead-attention.jsonc @@ -228,7 +228,7 @@ "data": null, "type": "int32" }, - // RelativePositionBias + // AttentionBias { "data": null, "type": "float32" @@ -293,7 +293,7 @@ "data": null, "type": "int32" }, - // RelativePositionBias + // AttentionBias { "data": null, "type": "float32" @@ -322,7 +322,7 @@ ] }, { - "name": "MultiHeadAttention Basic, one head and head-size=1 with optional RelativePositionBias, pastKey, pastValue inputs and optional presentKey, presentValue outputs", + "name": "MultiHeadAttention Basic, one head and head-size=1 with optional AttentionBias, pastKey, pastValue inputs and optional presentKey, presentValue outputs", "operator": "MultiHeadAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], @@ -358,7 +358,7 @@ "data": null, "type": "int32" }, - // RelativePositionBias + // AttentionBias { "data": null, "type": "float32" @@ -397,7 +397,7 @@ ] }, { - "name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, pastKey, pastValue inputs and optional presentKey, presentValue outputs", + "name": "MultiHeadAttention Basic, one head and head-size=4 with attentionBias, pastKey, pastValue inputs and optional presentKey, presentValue outputs", "operator": "MultiHeadAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], @@ -433,7 +433,7 @@ "data": null, "type": "int32" }, - // RelativePositionBias + // AttentionBias { "data": null, "type": "float32" @@ -474,7 +474,7 @@ ] }, { - "name": "MultiHeadAttention Basic, one head and head-size=1 with relativePositionBias, pastKey and pastValue", + "name": "MultiHeadAttention Basic, one head and head-size=1 with attentionBias, pastKey and pastValue", "operator": "MultiHeadAttention", "opset": { "domain": 
"com.microsoft", "version": 1 }, "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], @@ -510,7 +510,7 @@ "data": null, "type": "int32" }, - // RelativePositionBias + // AttentionBias { "data": [10, 20], "dims": [1, 1, 1, 2], @@ -540,7 +540,7 @@ ] }, { - "name": "MultiHeadAttention Basic, one head and head-size=4 with relativePositionBias, and pastValue", + "name": "MultiHeadAttention Basic, one head and head-size=4 with attentionBias, and pastValue", "operator": "MultiHeadAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], @@ -576,7 +576,7 @@ "data": null, "type": "int32" }, - // RelativePositionBias + // AttentionBias { "data": [100, 200], "dims": [1, 1, 1, 2], @@ -642,7 +642,7 @@ "data": null, "type": "int32" }, - // RelativePositionBias + // AttentionBias { "data": null, "type": "float32" @@ -717,7 +717,7 @@ "data": null, "type": "int32" }, - // RelativePositionBias + // AttentionBias { "data": null, "type": "float32" @@ -767,7 +767,7 @@ ] }, { - "name": "MultiHeadAttention Basic, one head and head-size one with RelativePositionBias, pastKey, pastValue, presentKey and presentValue", + "name": "MultiHeadAttention Basic, one head and head-size one with attentionBias, pastKey, pastValue, presentKey and presentValue", "operator": "MultiHeadAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], @@ -803,7 +803,7 @@ "data": null, "type": "int32" }, - // RelativePositionBias + // AttentionBias { "data": [10, 20], "dims": [1, 1, 1, 2], @@ -843,7 +843,7 @@ ] }, { - "name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, PastKey, PastValue inputs and PresentKey and PresentValue outputs", + "name": "MultiHeadAttention Basic, one head and head-size=4 with attentionBias, PastKey, PastValue inputs and PresentKey and PresentValue outputs", "operator": "MultiHeadAttention", 
"opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], @@ -879,7 +879,7 @@ "data": null, "type": "int32" }, - // RelativePositionBias + // AttentionBias { "data": [100, 200], "dims": [1, 1, 1, 2], @@ -957,7 +957,7 @@ "data": null, "type": "int32" }, - // RelativePositionBias + // AttentionBias { "data": [10, 20], "dims": [1, 1, 1, 2], @@ -1033,7 +1033,7 @@ "data": null, "type": "int32" }, - // RelativePositionBias + // AttentionBias { "data": [50, 100], "dims": [1, 1, 1, 2], diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc index f7d8fedc734e4..21e4b4c7932bc 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "contrib_ops/cpu/bert/attention_base.h" +#include "contrib_ops/cpu/bert/multihead_attention_helper.h" #include "core/providers/common.h" namespace onnxruntime { @@ -12,7 +13,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* relative_position_bias, + const Tensor* attention_bias, void* parameters, const Tensor* past_seq_len) const { // Abbreviation and Meanings: @@ -37,7 +38,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, // bias (Q/K/V) : (D + D + D_v) // mask_index : see below // past (K/V) : (2, B, N, P, H) or NULL - // relative_position_bias : (B, N, S, T) or NULL + // attention_bias : (B, N, S, T) or NULL // For mask_index, the following shapes are supported: // NULL, (B, 1), (1, 1) @@ -49,9 +50,9 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, // When a model is pruned (like some attention heads are removed in Q/K/V), input_hidden_size could be larger // than hidden dimension of Q, K and V. 
- if (past != nullptr && relative_position_bias != nullptr) { - // past is used on GPT-2 model with past state, we don't have a case for relative position bias yet - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Attention cannot have both past and relative_position_bias"); + if (past != nullptr && attention_bias != nullptr) { + // past is used on GPT-2 model with past state, we don't have a case for attention bias yet + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Attention cannot have both past and attention_bias"); } const auto& dims = input_shape.GetDims(); @@ -191,39 +192,12 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, } } - bool broadcast_res_pos_bias = false; - if (relative_position_bias != nullptr) { - const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); + gsl::span attention_bias_dims; + if (attention_bias != nullptr) { + attention_bias_dims = attention_bias->Shape().GetDims(); - if (relative_position_bias_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' is expected to have 4 dimensions, got ", - relative_position_bias_dims.size()); - } - - if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 0 should be same as batch_size or 1, got ", - relative_position_bias_dims[0]); - } - if (relative_position_bias_dims[0] == 1) { - broadcast_res_pos_bias = true; - } - if (relative_position_bias_dims[1] != num_heads_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ", - relative_position_bias_dims[1]); - } - if (relative_position_bias_dims[2] != sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ", - 
relative_position_bias_dims[2]); - } - if (relative_position_bias_dims[3] != total_sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ", - relative_position_bias_dims[3]); - } + ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckAttentionBias( + attention_bias_dims, batch_size, num_heads_, sequence_length, total_sequence_length)); } if (past != nullptr && past_present_share_buffer_) { @@ -257,7 +231,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, output_parameters->mask_filter_value = mask_filter_value_; output_parameters->scale = scale_; output_parameters->mask_type = mask_type; - output_parameters->broadcast_res_pos_bias = broadcast_res_pos_bias; + output_parameters->attention_bias_dims = attention_bias_dims; output_parameters->qkv_format = Q_K_V_BNSH; } @@ -329,7 +303,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* relative_position_bias, + const Tensor* attention_bias, void* parameters, const int max_threads_per_block, const Tensor* past_seq_len) const { @@ -337,7 +311,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); } - return CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, relative_position_bias, parameters, past_seq_len); + return CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, attention_bias, parameters, past_seq_len); } Tensor* AttentionBase::GetPresent(OpKernelContext* context, diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h index a6782daa58f1a..05756cd54d842 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h +++ 
b/onnxruntime/contrib_ops/cpu/bert/attention_base.h @@ -18,7 +18,7 @@ class AttentionBase { const TensorShape& bias_shape, const Tensor*& mask_index, // Dummy mask of shape (1 or batch_size, 1) will be updated to nullptr. const Tensor* past, - const Tensor* relative_position_bias, + const Tensor* attention_bias, void* parameters, const int max_threads_per_block, // for CUDA const Tensor* past_seq_len = nullptr) const; @@ -63,7 +63,7 @@ class AttentionBase { const TensorShape& bias_shape, const Tensor*& mask_index, // Dummy mask of shape (1 or batch_size, 1) will be updated to nullptr. const Tensor* past, - const Tensor* relative_position_bias, + const Tensor* attention_bias, void* parameters, const Tensor* past_seq_len = nullptr) const; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 88127387d08ea..6ea293ea3a870 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
#pragma once +#include namespace onnxruntime { namespace contrib { @@ -68,7 +69,7 @@ struct AttentionParameters { bool is_unidirectional; bool past_present_share_buffer; bool do_rotary; - bool broadcast_res_pos_bias; + gsl::span attention_bias_dims; float mask_filter_value; float scale; bool use_tf32; @@ -88,8 +89,7 @@ struct PackedAttentionParameters { int num_heads; float scale; int token_count; - bool has_relative_position_bias; - bool broadcast_res_pos_bias; + gsl::span attention_bias_dims; bool use_tf32; }; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index dd52001c2ac6b..a49cf60655a49 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -19,23 +19,23 @@ class AttentionCPUBase : public AttentionBase { : AttentionBase(info, require_same_hidden_size) {} template - Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH - const T* K, // K data with shape BxNxLxH - const T* V, // V value with size BxNxLxH_v - const Tensor* mask_index, // mask index. nullptr if no mask or its size is B - const Tensor* past, // past state - const Tensor* past_key, // past K input tensor (if not using past state) - const Tensor* past_value, // past V input tensor (if not using past state) - Tensor* output, // output tensor - Tensor* present_key, // present K output tensor (if separating present KV) - Tensor* present_value, // present V output tensor (if separating present KV) - int batch_size, // batch size (B) - int sequence_length, // sequence length of Q (S) - int kv_sequence_length, // sequence length of K or V (L) - int qk_head_size, // head size of Q or K (H) - int v_head_size, // head size of V (H_v) - int v_hidden_size, // hidden size of V (D_v) - const Tensor* relative_position_bias, // bias addition in QK. 
Its size is BxNxSxT + Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH + const T* K, // K data with shape BxNxLxH + const T* V, // V value with size BxNxLxH_v + const Tensor* mask_index, // mask index. nullptr if no mask or its size is B + const Tensor* past, // past state + const Tensor* past_key, // past K input tensor (if not using past state) + const Tensor* past_value, // past V input tensor (if not using past state) + Tensor* output, // output tensor + Tensor* present_key, // present K output tensor (if separating present KV) + Tensor* present_value, // present V output tensor (if separating present KV) + int batch_size, // batch size (B) + int sequence_length, // sequence length of Q (S) + int kv_sequence_length, // sequence length of K or V (L) + int qk_head_size, // head size of Q or K (H) + int v_head_size, // head size of V (H_v) + int v_hidden_size, // hidden size of V (D_v) + const Tensor* attn_bias, // additive bias applied on QK. Its size is BxNxSxT or 1xNxSxT OpKernelContext* context) const { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); @@ -66,10 +66,14 @@ class AttentionCPUBase : public AttentionBase { gsl::span mask_index_dims = mask_index != nullptr ? mask_index->Shape().GetDims() : gsl::span{}; + DUMP_CPU_TENSOR_INIT(); + DUMP_CPU_TENSOR("Mask", mask_index_data, mask_index_dims); + if (mask_data != nullptr) { + // Convert mask from boolean (0/1) to float (mask_filter_value/0.0f). + // Merge padding mask with causual mask, and broadcast to 3D (BxSxT). PrepareMask(mask_index_data, mask_index_dims, static_cast(mask_data), causal, batch_size, sequence_length, past_sequence_length, mask_filter_value_); - DUMP_CPU_TENSOR_INIT(); DUMP_CPU_TENSOR("Mask3D", static_cast(mask_data), batch_size, sequence_length, total_sequence_length); } @@ -82,10 +86,8 @@ class AttentionCPUBase : public AttentionBase { const T* past_value_data = past_value != nullptr ? 
past_value->Data() : nullptr; T* present_value_data = present_value != nullptr ? present_value->MutableData() : nullptr; - const T* relative_position_bias_data = nullptr; - if (relative_position_bias != nullptr) { - relative_position_bias_data = relative_position_bias->Data(); - } + const T* attn_bias_data = (attn_bias != nullptr) ? attn_bias->Data() : nullptr; + auto attn_bias_shape = (attn_bias != nullptr) ? attn_bias->Shape().GetDims() : gsl::span{}; // Compute the attention score. size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * total_sequence_length * sizeof(T); @@ -95,7 +97,7 @@ class AttentionCPUBase : public AttentionBase { static_cast(mask_data), batch_size, sequence_length, kv_sequence_length, past_sequence_length, qk_head_size == 0 ? v_head_size : qk_head_size, past_data, past_key_data, - present_data, present_key_data, tp, scale, relative_position_bias_data); + present_data, present_key_data, tp, scale, attn_bias_data, attn_bias_shape); // Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) auto out_tmp_data = @@ -115,29 +117,34 @@ class AttentionCPUBase : public AttentionBase { // 1 x mask_data(B, N, S, T) // attention_probs(B, N, S, T) = Softmax(attention_probs) template - void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT - const T* Q, // Q data. Its size is BxNxSxH - const T* K, // k data. Its size is BxNxLxH - T* mask_data, // buffer for mask data. 
- int batch_size, // batch size of self-attention - int sequence_length, // sequence length of self-attention (S) - int kv_sequence_length, // sequence length of cross-attention (L) - int past_sequence_length, // sequence length of past state - int head_size, // head size of self-attention - const T* past, // past state - const T* past_key, // past key only (if not using past state) - T* present, // present state - T* present_key, // present key only (if not using present state) - ThreadPool* tp, // thread pool - float scale, // scale factor - const T* relative_position_bias_data // bias addition matrix with shape BxNxSxT - ) const { + void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT + const T* Q, // Q data. Its size is BxNxSxH + const T* K, // k data. Its size is BxNxLxH + T* mask_data, // buffer for mask data. + int batch_size, // batch size of self-attention + int sequence_length, // sequence length of self-attention (S) + int kv_sequence_length, // sequence length of cross-attention (L) + int past_sequence_length, // sequence length of past state + int head_size, // head size of self-attention + const T* past, // past state + const T* past_key, // past key only (if not using past state) + T* present, // present state + T* present_key, // present key only (if not using present state) + ThreadPool* tp, // thread pool + float scale, // scale factor + const T* attn_bias_data, // bias addition matrix with shape BxNxSxT or 1xNxSxT + gsl::span attn_bias_shape) const { const int total_sequence_length = past_sequence_length + kv_sequence_length; // T = P + L const size_t past_chunk_length = static_cast(past_sequence_length) * head_size; // P x H const size_t q_input_chunk_length = static_cast(sequence_length) * head_size; // S x H const size_t kv_input_chunk_length = static_cast(kv_sequence_length) * head_size; // L x H const size_t present_chunk_length = past_chunk_length + kv_input_chunk_length; // T x H + DUMP_CPU_TENSOR_INIT(); + 
DUMP_CPU_TENSOR("Q", Q, batch_size, num_heads_, sequence_length, head_size); + DUMP_CPU_TENSOR("K", K, batch_size, num_heads_, total_sequence_length, head_size); + DUMP_CPU_TENSOR("Attn_Bias", attn_bias_data, attn_bias_shape); + { const int loop_len = batch_size * num_heads_; const float alpha = scale; @@ -160,7 +167,7 @@ class AttentionCPUBase : public AttentionBase { unit_cost.bytes_stored += bytes_to_copy_key; } - if (relative_position_bias_data != nullptr) { + if (attn_bias_data != nullptr) { unit_cost.compute_cycles += static_cast(sequence_length * total_sequence_length); unit_cost.bytes_loaded += probs_matrix_bytes * 2; unit_cost.bytes_stored += probs_matrix_bytes; @@ -169,13 +176,35 @@ class AttentionCPUBase : public AttentionBase { ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t i = begin; i != end; ++i) { const int batch_index = static_cast(i) / num_heads_; + const std::ptrdiff_t head_index = i % static_cast(num_heads_); const ptrdiff_t output_offset = SafeInt(i) * sequence_length * total_sequence_length; const ptrdiff_t mask_offset = SafeInt(batch_index) * sequence_length * total_sequence_length; + + ptrdiff_t attn_bias_offset = 0; + if (attn_bias_data != nullptr) { + // broadcast of batch dim with shape (1, N or 1, S, T) + if (attn_bias_shape[0] != 1) { + attn_bias_offset += SafeInt(batch_index) * num_heads_ * sequence_length * total_sequence_length; + } + + // broadcast of head dim with shape (B or 1, 1, S, T) + if (attn_bias_shape[1] != 1) { + attn_bias_offset += head_index * sequence_length * total_sequence_length; + } + } + T* output = attention_probs + output_offset; - // Broadcast mask data: (Bx)SxT -> (BxNx)SxT - if (mask_data != nullptr) { + if (attn_bias_data != nullptr) { + memcpy(output, attn_bias_data + attn_bias_offset, probs_matrix_bytes); + if (mask_data != nullptr) { + for (int j = 0; j < sequence_length * total_sequence_length; j++) { + output[j] += 
mask_data[mask_offset + j]; + } + } + } else if (mask_data != nullptr) { + // Broadcast mask data: (Bx)SxT -> (BxNx)SxT memcpy(output, mask_data + mask_offset, probs_matrix_bytes); } @@ -195,18 +224,10 @@ class AttentionCPUBase : public AttentionBase { math::Gemm(CblasNoTrans, CblasTrans, sequence_length, total_sequence_length, head_size, alpha, Q + q_input_chunk_length * i, k, mask_data != nullptr ? 1.0f : 0.0f, output, nullptr); - - if (relative_position_bias_data != nullptr) { - for (int j = 0; j < sequence_length * total_sequence_length; j++) { - output[j] += relative_position_bias_data[output_offset + j]; - } - } } }); } - DUMP_CPU_TENSOR_INIT(); - DUMP_CPU_TENSOR("Q", Q, batch_size, num_heads_, sequence_length, head_size); DUMP_CPU_TENSOR("QK (scaled)", attention_probs, batch_size, num_heads_, sequence_length, total_sequence_length); // attention_probs(B, N, S, T) = Softmax(attention_probs) diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index 0d77376779230..ca818f09c4b1e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -57,7 +57,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { const Tensor* value = context->Input(2); const Tensor* bias = context->Input(3); const Tensor* key_padding_mask = context->Input(4); - const Tensor* extra_add_qk = context->Input(5); + const Tensor* attn_bias = context->Input(5); const Tensor* past_key = context->Input(6); const Tensor* past_value = context->Input(7); @@ -75,7 +75,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { value, bias, key_padding_mask, - extra_add_qk, + attn_bias, past_key, past_value, nullptr, @@ -135,7 +135,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { value->Data(), key_padding_mask, nullptr /* past */, past_key, past_value, output, present_k, present_v, batch_size, 
q_sequence_length, kv_sequence_length, - qk_head_size, v_head_size, v_hidden_size, extra_add_qk, context); + qk_head_size, v_head_size, v_hidden_size, attn_bias, context); } OrtValue K; @@ -149,7 +149,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { !disable_flash_ && !is_unidirectional_ && key_padding_mask == nullptr && - extra_add_qk == nullptr && + attn_bias == nullptr && past_key == nullptr && past_value == nullptr && present_k == nullptr && @@ -215,7 +215,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { V.GetMutable()->MutableData(), key_padding_mask, nullptr /* past */, past_key, past_value, output, present_k, present_v, batch_size, q_sequence_length, kv_sequence_length, - qk_head_size, v_head_size, v_hidden_size, extra_add_qk, context); + qk_head_size, v_head_size, v_hidden_size, attn_bias, context); } } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index cfb8d36843777..85f223f8ec7a4 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -179,39 +179,35 @@ Status CheckPast(const T* past_key, const T* past_value, const T* past_seq_len, return Status::OK(); } -template -Status CheckRelativePositionBias( - const T* relative_position_bias, int batch_size, int num_heads, int sequence_length, int total_sequence_length, - bool& broadcast_res_pos_bias) { - const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); - - if (relative_position_bias_dims.size() != 4) { +inline Status CheckAttentionBias( + const gsl::span& attention_bias_dims, + int batch_size, int num_heads, int sequence_length, int total_sequence_length) { + if (attention_bias_dims.size() != 4) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' is expected to 
have 4 dimensions, got ", - relative_position_bias_dims.size()); + "Input 'attention_bias' is expected to have 4 dimensions, got ", + attention_bias_dims.size()); } - if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) { + + if (attention_bias_dims[0] != batch_size && attention_bias_dims[0] != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 0 should be batch_size or 1, got ", - relative_position_bias_dims[0]); + "Input 'attention_bias' dimension 0 should be batch_size or 1, got ", + attention_bias_dims[0]); } - if (relative_position_bias_dims[0] == 1) { - broadcast_res_pos_bias = true; - } - if (relative_position_bias_dims[1] != num_heads) { + + if (attention_bias_dims[1] != num_heads && attention_bias_dims[1] != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ", - relative_position_bias_dims[1]); + "Input 'attention_bias' dimension 1 should be same as number of heads or 1, got ", + attention_bias_dims[1]); } - if (relative_position_bias_dims[2] != sequence_length) { + if (attention_bias_dims[2] != sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ", - relative_position_bias_dims[2]); + "Input 'attention_bias' dimension 2 should be same as sequence_length, got ", + attention_bias_dims[2]); } - if (relative_position_bias_dims[3] != total_sequence_length) { + if (attention_bias_dims[3] != total_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ", - relative_position_bias_dims[3]); + "Input 'attention_bias' dimension 3 should be same as total_sequence_length, got ", + attention_bias_dims[3]); } return Status::OK(); } @@ -243,7 +239,7 @@ Status CheckInputs(const T* 
query, const T* value, const T* bias, const T* key_padding_mask, - const T* relative_position_bias, + const T* attention_bias, const T* past_key, const T* past_value, const T* past_seq_len, @@ -258,13 +254,15 @@ Status CheckInputs(const T* query, // Notations: // B: batch_size // N: num_heads - // H: head_size (V might have different head size than Q and K) - // D: hidden_size = N * H + // H: head_size of Q and K. + // H_v: head_size of V. + // D: hidden_size of Q and K, where D = N * H + // D_v: hidden_size of V, where D_v = N * H_v // S: q_sequence_length - // P: past_sequence_length + // P: past_sequence_length of kv cache // L: kv_sequence_length // T: total_sequence_length = P + L - // M: max_sequence_length + // M: max_sequence_length of kv cache when past and present share buffer // --------------------------------------------------------------- // MultiHeadAttention inputs: // --------------------------------------------------------------- @@ -275,7 +273,7 @@ Status CheckInputs(const T* query, // Q_K_V_BSNH_BNSH_BNSH - cross attention (kv cache is not used, L == T, D == D_v): // query (Q) : (B, S, D) // key (K) : (B, N, L, H) - // value (V) : (B, N, L, H) + // value (V) : (B, N, L, H_v) // Q_KV_BSNH_BSN2H - packed kv (kv cache is not used, bias is not allowed for packed kv): // query (Q) : (B, S, D) // key (K/V) : (B, L, N, 2, H) @@ -288,7 +286,7 @@ Status CheckInputs(const T* query, // Other inputs: // bias (Q/K/V) : None or (D + D + D_v) // key_padding_mask (K/V) : (B) or (3 * B + 2) or (B, T) or (B, S, T) - // relative_position_bias : (B, N, S, T) or (1, N, S, T) + // attention_bias : (B, N, S, T), (1, N, S, T), (B, 1, S, T) or (1, 1, S, T) // past_key : (B, N, P, H) or None. Past state is only allowed for Q_K_V_BSNH. // past_value : (B, N, P, H) or None. Past state is only allowed for Q_K_V_BSNH. 
// --------------------------------------------------------------- @@ -298,7 +296,7 @@ Status CheckInputs(const T* query, // query (Q) : (B, S, D) // key (K) : (B, L, D) // value (V) : (B, L, D) - // Q_K_V_BSNH_BNSH_BNSH - cross attention (kv cache and relative_position_bias are not used. L == T): + // Q_K_V_BSNH_BNSH_BNSH - cross attention (kv cache and attention_bias are not used. L == T): // query (Q) : (B, S, D) // key (K) : (B, N, L, H) // value (V) : (B, N, L, H) @@ -310,7 +308,7 @@ Status CheckInputs(const T* query, // Other inputs: // bias (Q/K/V) : None or (3 * D) // key_padding_mask (K/V) : None or (B, T) - // relative_position_bias : (1, N, S, T), or (B, N, S, T) where only 1 x N x S x T data is used in CUDA. + // attention_bias : (1, N, S, T), or (B, N, S, T) where only 1 x N x S x T data is used in CUDA. // // The following inputs are not used in cross attention (so they are None for cross attention): // past_key : (B, N, P, H), or (B, N, M, H) when past_present_share_buffer is True. 
@@ -401,10 +399,11 @@ Status CheckInputs(const T* query, } } - bool broadcast_res_pos_bias = false; - if (relative_position_bias != nullptr) { - ORT_RETURN_IF_ERROR(CheckRelativePositionBias( - relative_position_bias, batch_size, num_heads, sequence_length, total_sequence_length, broadcast_res_pos_bias)); + gsl::span attention_bias_dims; + if (attention_bias != nullptr) { + attention_bias_dims = attention_bias->Shape().GetDims(); + ORT_RETURN_IF_ERROR(CheckAttentionBias( + attention_bias_dims, batch_size, num_heads, sequence_length, total_sequence_length)); } assert(qkv_format != UNKNOWN); @@ -428,7 +427,7 @@ Status CheckInputs(const T* query, output_parameters->mask_filter_value = mask_filter_value; output_parameters->mask_type = mask_type; output_parameters->scale = scale; - output_parameters->broadcast_res_pos_bias = broadcast_res_pos_bias; + output_parameters->attention_bias_dims = attention_bias_dims; output_parameters->qkv_format = qkv_format; } @@ -441,7 +440,7 @@ Status CheckInputs(const T* query, const T* value, const T* bias, const T* key_padding_mask, - const T* relative_position_bias, + const T* attention_bias, const T* past_key, const T* past_value, const T* past_seq_len, @@ -457,7 +456,7 @@ Status CheckInputs(const T* query, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); } - return CheckInputs(query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value, + return CheckInputs(query, key, value, bias, key_padding_mask, attention_bias, past_key, past_value, past_seq_len, parameters, num_heads, mask_filter_value, scale, is_unidirectional, past_present_share_buffer, operator_type); } diff --git a/onnxruntime/contrib_ops/cpu/utils/console_dumper.h b/onnxruntime/contrib_ops/cpu/utils/console_dumper.h index 2782a59d4326d..ff7921fc70da3 100644 --- a/onnxruntime/contrib_ops/cpu/utils/console_dumper.h +++ b/onnxruntime/contrib_ops/cpu/utils/console_dumper.h @@ -32,6 
+32,11 @@ class IConsoleDumper { virtual void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2, int dim3) const = 0; virtual void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2, int dim3) const = 0; + virtual void Print(const char* name, const int32_t* tensor, gsl::span& dims) const = 0; + virtual void Print(const char* name, const int64_t* tensor, gsl::span& dims) const = 0; + virtual void Print(const char* name, const float* tensor, gsl::span& dims) const = 0; + virtual void Print(const char* name, const MLFloat16* tensor, gsl::span& dims) const = 0; + virtual void Print(const char* name, const Tensor& value) const = 0; virtual void Print(const char* name, const OrtValue& value) const = 0; virtual void Print(const char* name, int index, bool end_line) const = 0; @@ -43,5 +48,38 @@ class IConsoleDumper { bool is_enabled_; }; +template +void PrintTensorByDims(const TConsoleDumper* dumper, + const char* name, + const T* tensor, + gsl::span& dims) { + if (dumper->IsEnabled && (tensor == nullptr || dims.size() == 0)) { + std::cout << std::string(name) << " is None" << std::endl; + return; + } + + auto num_dims = dims.size(); + if (num_dims == 1) { + dumper->Print(name, tensor, 1, static_cast(dims[0])); + } else if (num_dims == 2) { + dumper->Print(name, tensor, static_cast(dims[0]), static_cast(dims[1])); + } else if (num_dims == 3) { + dumper->Print(name, tensor, static_cast(dims[0]), static_cast(dims[1]), static_cast(dims[2])); + } else if (num_dims == 4) { + dumper->Print(name, tensor, + static_cast(dims[0]), + static_cast(dims[1]), + static_cast(dims[2]), + static_cast(dims[3])); + } else if (num_dims == 5) { + dumper->Print(name, tensor, + static_cast(dims[0]) * static_cast(dims[1]), + static_cast(dims[2]), + static_cast(dims[3]), + static_cast(dims[4])); + } else { + ORT_ENFORCE(false, "Unsupported tensor dims"); + } +} } // namespace contrib } // namespace onnxruntime diff --git 
a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc index 87a9cd3965763..7755f9505d99d 100644 --- a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc +++ b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc @@ -246,7 +246,24 @@ void CpuTensorConsoleDumper::Print(const char* name, const std::string& value, b } } +void CpuTensorConsoleDumper::Print(const char* name, const int32_t* tensor, gsl::span& dims) const { + PrintTensorByDims(this, name, tensor, dims); +} + +void CpuTensorConsoleDumper::Print(const char* name, const int64_t* tensor, gsl::span& dims) const { + PrintTensorByDims(this, name, tensor, dims); +} + +void CpuTensorConsoleDumper::Print(const char* name, const float* tensor, gsl::span& dims) const { + PrintTensorByDims(this, name, tensor, dims); +} + +void CpuTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, gsl::span& dims) const { + PrintTensorByDims(this, name, tensor, dims); +} + #else + CpuTensorConsoleDumper::CpuTensorConsoleDumper() { } @@ -303,6 +320,18 @@ void CpuTensorConsoleDumper::Print(const char*, int, bool) const { void CpuTensorConsoleDumper::Print(const char*, const std::string&, bool) const { } + +void CpuTensorConsoleDumper::Print(const char*, const int32_t*, gsl::span&) const { +} + +void CpuTensorConsoleDumper::Print(const char*, const int64_t*, gsl::span&) const { +} + +void CpuTensorConsoleDumper::Print(const char*, const float*, gsl::span&) const { +} + +void CpuTensorConsoleDumper::Print(const char*, const MLFloat16*, gsl::span&) const { +} #endif } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h index f102eae6ec709..6fc4dfd4a0671 100644 --- a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h +++ b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h @@ -30,6 +30,11 @@ class CpuTensorConsoleDumper : public IConsoleDumper { void Print(const char* name, const int64_t* tensor, int dim0, int 
dim1, int dim2, int dim3) const override; void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2, int dim3) const override; + void Print(const char* name, const int32_t* tensor, gsl::span& dims) const override; + void Print(const char* name, const int64_t* tensor, gsl::span& dims) const override; + void Print(const char* name, const float* tensor, gsl::span& dims) const override; + void Print(const char* name, const MLFloat16* tensor, gsl::span& dims) const override; + void Print(const char* name, const Tensor& value) const override; void Print(const char* name, const OrtValue& value) const override; void Print(const char* name, int index, bool end_line) const override; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 5c0989bced70c..1d1416995a673 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -59,7 +59,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { const Tensor* bias = context->Input(2); const Tensor* mask_index = context->Input(3); const Tensor* past = context->Input(kPastInputIndex); - const Tensor* relative_position_bias = context->Input(5); + const Tensor* attention_bias = context->Input(5); const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); auto& device_prop = GetDeviceProp(); @@ -74,7 +74,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { bias != nullptr ? 
bias->Shape() : bias_shape, mask_index, past, - relative_position_bias, + attention_bias, ¶meters, device_prop.maxThreadsPerBlock, past_seq_len)); @@ -104,7 +104,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { #if USE_FLASH_ATTENTION bool use_flash_attention = !disable_flash_attention_ && - (nullptr == relative_position_bias) && + (nullptr == attention_bias) && nullptr == past && nullptr == present && parameters.hidden_size == parameters.v_hidden_size && @@ -146,7 +146,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { // where past state is empty. bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING; bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) && - nullptr == relative_position_bias && + nullptr == attention_bias && parameters.past_sequence_length == 0 && parameters.hidden_size == parameters.v_hidden_size && FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length, @@ -169,7 +169,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { (nullptr == mask_index || is_mask_1d_seq_len) && nullptr == past && nullptr == present && - nullptr == relative_position_bias && + nullptr == attention_bias && parameters.hidden_size == parameters.v_hidden_size && FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length, enable_trt_flash_attention_, false); @@ -201,12 +201,9 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { nullptr == present && (nullptr == mask_index || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) && (sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) && + (nullptr == attention_bias || parameters.sequence_length % (4 * sizeof(T)) == 0) && has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size); - if 
(use_memory_efficient_attention) { - bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0; - use_memory_efficient_attention = (nullptr == relative_position_bias || is_good_for_rpb); - } #else constexpr bool use_memory_efficient_attention = false; #endif @@ -277,8 +274,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { if (nullptr != past) { data.past = reinterpret_cast(past->Data()); } - if (nullptr != relative_position_bias) { - data.relative_position_bias = reinterpret_cast(relative_position_bias->Data()); + if (nullptr != attention_bias) { + data.attention_bias = reinterpret_cast(attention_bias->Data()); } data.has_qkv_workspace = true; data.workspace = reinterpret_cast(work_space.get()); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index f9eabe27d97e4..5508388e99257 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -290,7 +290,7 @@ Status FlashAttention( assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH || data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); assert(nullptr == data.mask_index); - assert(nullptr == data.relative_position_bias); + assert(nullptr == data.attention_bias); assert(parameters.head_size == parameters.v_head_size); constexpr bool is_bf16 = false; @@ -332,6 +332,8 @@ Status EfficientAttention( // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. 
assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH || data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); + assert(parameters.mask_type == AttentionMaskType::MASK_NONE || + parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START); MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; @@ -345,22 +347,22 @@ Status EfficientAttention( p.v_head_size = parameters.v_head_size; p.causal = parameters.is_unidirectional; p.scale = scale; - p.seqlen_k_ptr = nullptr == data.mask_index - ? nullptr - : const_cast(reinterpret_cast(data.mask_index)); - p.seqstart_q_ptr = nullptr == data.mask_index - ? nullptr - : const_cast(reinterpret_cast( - data.mask_index + parameters.batch_size)); - p.seqstart_k_ptr = nullptr == data.mask_index - ? nullptr - : const_cast(reinterpret_cast( - data.mask_index + 2 * parameters.batch_size + 1)); + + if (nullptr == data.mask_index) { + p.seqlen_k_ptr = nullptr; + p.seqstart_q_ptr = nullptr; + p.seqstart_k_ptr = nullptr; + } else { + p.seqlen_k_ptr = const_cast(reinterpret_cast(data.mask_index)); + p.seqstart_q_ptr = const_cast(reinterpret_cast(data.mask_index + parameters.batch_size)); + p.seqstart_k_ptr = const_cast(reinterpret_cast(data.mask_index + 2 * parameters.batch_size + 1)); + } + p.query = data.q; p.key = data.k; p.value = data.v; - p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias; - p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; + p.attn_bias = (nullptr == data.attention_bias) ? 
nullptr : data.attention_bias; + p.attn_bias_dims = data.attention_bias_dims; p.output = data.output; p.is_kv_bsnh = data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH; p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float)) @@ -415,6 +417,12 @@ Status UnfusedAttention( const int present_size_per_batch_k = present_sequence_length * qk_head_size; const int present_size_per_batch_v = present_sequence_length * v_head_size; + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("q", data.q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k", data.k, batch_size, num_heads, total_sequence_length, qk_head_size); + DUMP_TENSOR_D("v", data.v, batch_size, num_heads, total_sequence_length, v_head_size); + DUMP_TENSOR_D("mask_index", mask_index, mask_index_dims); + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_T, CUBLAS_OP_N, total_sequence_length, sequence_length, qk_head_size, @@ -423,7 +431,6 @@ Status UnfusedAttention( &zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop, parameters.use_tf32)); - DUMP_TENSOR_INIT(); DUMP_TENSOR_D("QK", data.scratch, batch_size, num_heads, sequence_length, total_sequence_length); constexpr size_t element_size = sizeof(T); @@ -431,6 +438,9 @@ Status UnfusedAttention( sequence_length, total_sequence_length); T* scratch2 = data.scratch + (bytes / element_size); + bool broadcast_attn_bias_dim_0 = parameters.attention_bias_dims.size() > 0 && parameters.attention_bias_dims[0] == 1; + bool broadcast_attn_bias_dim_1 = parameters.attention_bias_dims.size() > 1 && parameters.attention_bias_dims[1] == 1; + // Apply softmax and store result R to scratch2: BxNxSxT if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask const int mask_dimension = static_cast(mask_index_dims.size()); @@ -444,7 +454,7 @@ Status UnfusedAttention( ORT_RETURN_IF_ERROR( ComputeSoftmaxWithRawMask( ort_stream, total_sequence_length, 
sequence_length, batch_size, num_heads, - mask_index, nullptr, data.relative_position_bias, parameters.broadcast_res_pos_bias, + mask_index, nullptr, data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, data.scratch, scratch2, parameters.is_unidirectional, scale, mask_dimension, parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace, parameters.mask_filter_value)); @@ -454,17 +464,17 @@ Status UnfusedAttention( const int* mask_start = (mask_index_dims[0] > batch_size) ? mask_index + batch_size : nullptr; ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D( stream, total_sequence_length, sequence_length, batch_size, num_heads, - mask_index, mask_start, data.relative_position_bias, parameters.broadcast_res_pos_bias, + mask_index, mask_start, data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, data.scratch, scratch2, parameters.is_unidirectional)); } else { // no mask ORT_RETURN_IF_ERROR( ComputeSoftmax( - stream, total_sequence_length, sequence_length, batch_size, num_heads, data.relative_position_bias, - parameters.broadcast_res_pos_bias, data.scratch, scratch2, parameters.is_unidirectional)); + stream, total_sequence_length, sequence_length, batch_size, num_heads, + data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + data.scratch, scratch2, parameters.is_unidirectional)); } DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length); - DUMP_TENSOR_D("V", data.v, batch_size, num_heads, sequence_length, v_head_size); // compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v T* temp_output = data.q; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index fad353dcfeb07..29b6c1f53a7e3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -69,7 +69,8 @@ struct AttentionData { const T* 
past = nullptr; const T* past_key = nullptr; const T* past_value = nullptr; - const T* relative_position_bias = nullptr; + const T* attention_bias = nullptr; + gsl::span attention_bias_dims; bool has_qkv_workspace = false; T* workspace = nullptr; @@ -115,7 +116,7 @@ struct AttentionData { << ", fused_runner=" << (fused_runner != nullptr) << ", fused_cross=" << (fused_cross_attention_kernel != nullptr) << ", bias=" << (bias != nullptr) - << ", attn_bias=" << (relative_position_bias != nullptr) + << ", attn_bias=" << (attention_bias != nullptr) << ", mask_dims=" << mask_index_dims.size() << ", has_qkv_workspace=" << has_qkv_workspace << ", workspace=" << workspace_bytes diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index 05c592ec61059..d34e6a92bab03 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -77,10 +77,8 @@ void DumpInputs(contrib::AttentionParameters& parameters, AttentionData& data DUMP_TENSOR_D("V_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); } - if (data.relative_position_bias != nullptr) { - DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias, - parameters.broadcast_res_pos_bias ? 
1 : batch_size, - num_heads, sequence_length, kv_sequence_length); + if (data.attention_bias != nullptr) { + DUMP_TENSOR_D("attention_bias", data.attention_bias, parameters.attention_bias_dims); } if (data.mask_index != nullptr) { @@ -258,7 +256,7 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters, if (data.fused_cross_attention_kernel != nullptr) { assert(qk_head_size == v_head_size); - assert(data.relative_position_bias == nullptr); + assert(data.attention_bias == nullptr); assert(data.mask_index == nullptr); assert(parameters.hidden_size == parameters.v_hidden_size); @@ -290,7 +288,7 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters, #endif else if (data.fused_runner != nullptr) { assert(qk_head_size == v_head_size); - assert(data.relative_position_bias == nullptr); + assert(data.attention_bias == nullptr); // Query (BxSxNxH), Key (BxSxNxH), Value (BxSxNxH) => Q: BxSxNx(H + H + H) LaunchAddBiasTransposeTrt( @@ -524,7 +522,7 @@ Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters, true, v_head_size, qkv_add_bias, 3); data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH; } else if (nullptr != data.fused_runner) { - assert(nullptr == data.relative_position_bias); + assert(nullptr == data.attention_bias); if (data.bias == nullptr) { // When there is no bias, we can directly use the original packed QKV input. // Need revisit this when we add support for causal. diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu index 01ea02f48d3ab..494e708a85485 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu @@ -29,12 +29,14 @@ namespace onnxruntime { namespace contrib { namespace attention_softmax_cuda { +// This kernel is for non causal, attention mask 1D or None, and total_sequence_length > 1024. 
template -__device__ inline void Softmax(const int all_sequence_length, +__device__ inline void Softmax(const int total_sequence_length, const int valid_end, const int valid_start, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output) { using BlockReduce = cub::BlockReduce; @@ -45,28 +47,42 @@ __device__ inline void Softmax(const int all_sequence_length, float thread_data_max(-CUDART_INF_F); - const bool no_rpb = (rel_pos_bias == nullptr); + // grid size is (sequence_length, num_heads, batch_size); total_sequence_length is partitioned to blocks by TPB. + const int sequence_length = gridDim.x; + const int num_heads = gridDim.y; + const int batch_size = gridDim.z; + const int s = blockIdx.x; + const int n = blockIdx.y; + const int b = blockIdx.z; + + // input and output shape is (batch_size, num_heads, sequence_length, total_sequence_length) + int block = b * num_heads * sequence_length + n * sequence_length + s; + const int64_t offset = static_cast(block) * static_cast(total_sequence_length); + + const bool has_bias = (attn_bias != nullptr); + int64_t bias_offset = 0; + if (has_bias) { + // bias shape is (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length) + block = (broadcast_attn_bias_dim_0 ? 0 : (b * num_heads * sequence_length)) + + (broadcast_attn_bias_dim_1 ? 0 : (n * sequence_length)) + + s; + bias_offset = static_cast(block) * static_cast(total_sequence_length); + } // e^x is represented as infinity if x is large enough, like 100.f. // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. // a math transform as below is leveraged to get a stable softmax: // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... 
+ e^(xn - max)) - const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; - const int size_per_batch = gridDim.x * all_sequence_length; for (int i = threadIdx.x; i < valid_end; i += TPB) { if (i >= valid_start) { - const int index = offset + i; - float input_at_idx = no_rpb - ? float(input[index]) - : float(input[index] + (broadcast_rel_pos_bias - ? rel_pos_bias[index % size_per_batch] - : rel_pos_bias[index])); - if (thread_data_max < input_at_idx) { - thread_data_max = input_at_idx; + float input_data = has_bias + ? float(input[offset + i]) + float(attn_bias[bias_offset + i]) + : float(input[offset + i]); + if (thread_data_max < input_data) { + thread_data_max = input_data; } } } - const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max()); // Store max value @@ -78,9 +94,11 @@ __device__ inline void Softmax(const int all_sequence_length, float thread_data_sum(0.f); for (int i = threadIdx.x; i < valid_end; i += TPB) { if (i >= valid_start) { - const int index = offset + i; - float val = no_rpb ? input[index] : input[index] + rel_pos_bias[index % size_per_batch]; - thread_data_sum += expf(val - max_block); + float input_data = has_bias + ? float(input[offset + i]) + float(attn_bias[bias_offset + i]) + : float(input[offset + i]); + + thread_data_sum += expf(input_data - max_block); } } @@ -90,21 +108,24 @@ __device__ inline void Softmax(const int all_sequence_length, } __syncthreads(); - for (int i = threadIdx.x; i < all_sequence_length; i += TPB) { + for (int i = threadIdx.x; i < total_sequence_length; i += TPB) { const int index = offset + i; - float input_at_idx = no_rpb ? float(input[index]) : float(input[index] + rel_pos_bias[index % size_per_batch]); - const float val = (i >= valid_start && i < valid_end) ? expf(input_at_idx - max_block) * sum_reverse_block : 0.f; + float input_data = has_bias + ? 
float(input[index]) + float(attn_bias[bias_offset + i]) + : float(input[index]); + const float val = (i >= valid_start && i < valid_end) ? expf(input_data - max_block) * sum_reverse_block : 0.f; output[index] = T(val); } } +// This kernel is for causal or not, attention mask 1D or None, and total_sequence_length <= 1024. template -__device__ inline void SoftmaxSmall(const int all_sequence_length, - const int sequence_length, +__device__ inline void SoftmaxSmall(const int total_sequence_length, const int valid_end, const int valid_start, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output, bool causal) { @@ -114,34 +135,49 @@ __device__ inline void SoftmaxSmall(const int all_sequence_length, __shared__ float sum_reverse_block; __shared__ float max_block; - // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; - const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; + // grid size is (sequence_length, num_heads, batch_size); total_sequence_length is within one block size TPB. + const int sequence_length = gridDim.x; + const int num_heads = gridDim.y; + const int batch_size = gridDim.z; + const int s = blockIdx.x; + const int n = blockIdx.y; + const int b = blockIdx.z; + + // input and output shape is (batch_size, num_heads, sequence_length, total_sequence_length) + int block = b * num_heads * sequence_length + n * sequence_length + s; + const int64_t offset = static_cast(block) * static_cast(total_sequence_length); const int index = offset + threadIdx.x; + const bool has_bias = (attn_bias != nullptr); + int64_t bias_offset = 0; + if (has_bias) { + // bias shape is (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length) + block = (broadcast_attn_bias_dim_0 ? 0 : (b * num_heads * sequence_length)) + + (broadcast_attn_bias_dim_1 ? 
0 : (n * sequence_length)) + + s; + bias_offset = static_cast(block) * static_cast(total_sequence_length); + } + // Update end position for causal. int end = valid_end; if (causal) { - const int end_causal = all_sequence_length - sequence_length + (blockIdx.x % sequence_length) + 1; + const int end_causal = total_sequence_length - sequence_length + s + 1; if (end_causal < end) { end = end_causal; } } const bool is_valid = (threadIdx.x >= valid_start && threadIdx.x < end); + float input_data = is_valid ? (has_bias + ? float(input[index]) + float(attn_bias[bias_offset + threadIdx.x]) + : float(input[index])) + : float(-CUDART_INF_F); // e^x is represented as infinity if x is large enough, like 100.f. // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. // a math transform as below is leveraged to get a stable softmax: // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) - const bool no_rpb = (rel_pos_bias == nullptr); - const int size_per_batch = gridDim.x * all_sequence_length; - float input_data = no_rpb - ? float(input[index]) - : float(input[index] + (broadcast_rel_pos_bias - ? rel_pos_bias[index % size_per_batch] - : rel_pos_bias[index])); - float thread_data_max = is_valid ? input_data : float(-CUDART_INF_F); - const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end); + const auto max = BlockReduce(tmp_storage).Reduce(input_data, cub::Max(), end); // Store max value if (threadIdx.x == 0) { @@ -162,23 +198,24 @@ __device__ inline void SoftmaxSmall(const int all_sequence_length, } __syncthreads(); - // threadIdx.x might be larger than all_sequence_length due to alignment to 32x. - if (threadIdx.x < all_sequence_length) { + // threadIdx.x might be larger than total_sequence_length due to alignment to 32x. + if (threadIdx.x < total_sequence_length) { output[index] = is_valid ? 
T(thread_data_exp * sum_reverse_block) : T(0.f); } } +// This kernel is for causal or not, attention mask 1D or None, and total_sequence_length > 1024. template -__global__ void SoftmaxLargeKernel(const int all_sequence_length, - const int sequence_length, +__global__ void SoftmaxLargeKernel(const int total_sequence_length, const int valid_end, const int valid_start, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output, bool causal) { - extern __shared__ float cached_data[]; // float[all_sequence_length] + extern __shared__ float cached_data[]; // float[total_sequence_length] using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmp_storage; @@ -186,36 +223,46 @@ __global__ void SoftmaxLargeKernel(const int all_sequence_length, __shared__ float sum_reverse_block; __shared__ float max_block; + // grid size is (sequence_length, num_heads, batch_size); total_sequence_length is partitioned by TPB. + const int sequence_length = gridDim.x; + const int num_heads = gridDim.y; + const int batch_size = gridDim.z; + const int s = blockIdx.x; + const int n = blockIdx.y; + const int b = blockIdx.z; + // Update end position for causal. 
int end = valid_end; if (causal) { - int end_causal = all_sequence_length - sequence_length + (blockIdx.x % sequence_length) + 1; + int end_causal = total_sequence_length - sequence_length + s + 1; if (end_causal < end) { end = end_causal; } } - // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; - const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; - const int size_per_batch = gridDim.x * all_sequence_length; + // input and output shape is (batch_size, num_heads, sequence_length, total_sequence_length) + int block = b * num_heads * sequence_length + n * sequence_length + s; + const int64_t offset = static_cast(block) * static_cast(total_sequence_length); + + const bool has_bias = (attn_bias != nullptr); + int64_t bias_offset = 0; + if (has_bias) { + // bias shape is (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length) + block = (broadcast_attn_bias_dim_0 ? 0 : (b * num_heads * sequence_length)) + + (broadcast_attn_bias_dim_1 ? 0 : (n * sequence_length)) + + s; + bias_offset = static_cast(block) * static_cast(total_sequence_length); + } float thread_data_max = -CUDART_INF_F; - for (int seq_idx = threadIdx.x; seq_idx < all_sequence_length; seq_idx += TPB) { - const int index = offset + seq_idx; - const bool is_valid = (seq_idx >= valid_start && seq_idx < end); - - // e^x is represented as infinity if x is large enough, like 100.f. - // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. - // a math transform as below is leveraged to get a stable softmax: - // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) - float input_data = is_valid - ? (rel_pos_bias - ? float(input[index] + (broadcast_rel_pos_bias - ? 
rel_pos_bias[index % size_per_batch] - : rel_pos_bias[index])) - : float(input[index])) - : float(-CUDART_INF_F); - cached_data[seq_idx] = input_data; + for (int i = threadIdx.x; i < total_sequence_length; i += TPB) { + const int index = offset + i; + const bool is_valid = (i >= valid_start && i < end); + float input_data = is_valid ? (has_bias + ? float(input[index]) + float(attn_bias[bias_offset + i]) + : float(input[index])) + : float(-CUDART_INF_F); + cached_data[i] = input_data; thread_data_max = max(thread_data_max, input_data); } const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end); @@ -227,10 +274,10 @@ __global__ void SoftmaxLargeKernel(const int all_sequence_length, __syncthreads(); float thread_data_exp(0.f); - for (int seq_idx = threadIdx.x; seq_idx < all_sequence_length; seq_idx += TPB) { - const bool is_valid = (seq_idx >= valid_start && seq_idx < end); - cached_data[seq_idx] = is_valid ? expf(cached_data[seq_idx] - max_block) : 0.0f; - thread_data_exp += cached_data[seq_idx]; + for (int i = threadIdx.x; i < total_sequence_length; i += TPB) { + const bool is_valid = (i >= valid_start && i < end); + cached_data[i] = is_valid ? expf(cached_data[i] - max_block) : 0.0f; + thread_data_exp += cached_data[i]; } const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), end); @@ -240,20 +287,21 @@ __global__ void SoftmaxLargeKernel(const int all_sequence_length, } __syncthreads(); - // threadIdx.x might be larger than all_sequence_length due to alignment to 32x. - for (int seq_idx = threadIdx.x; seq_idx < all_sequence_length; seq_idx += TPB) { - const bool is_valid = (seq_idx >= valid_start && seq_idx < end); - output[offset + seq_idx] = is_valid ? T(cached_data[seq_idx] * sum_reverse_block) : T(0.f); + // threadIdx.x might be larger than total_sequence_length due to alignment to 32x. 
+ for (int i = threadIdx.x; i < total_sequence_length; i += TPB) { + const bool is_valid = (i >= valid_start && i < end); + output[offset + i] = is_valid ? T(cached_data[i] * sum_reverse_block) : T(0.f); } } +// This kernel is for causal or not, raw attention mask (2D, 3D or 4D) and total_sequence_length > 1024. template -__global__ void SoftmaxWithRawMaskLargeKernel(const int all_sequence_length, - const int sequence_length, +__global__ void SoftmaxWithRawMaskLargeKernel(const int total_sequence_length, const int* attention_mask, // 2D, 3D or 4D attention mask const bool* key_padding_mask, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output, const bool causal, @@ -262,7 +310,7 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int all_sequence_length, const int max_sequence_length, const bool skip_softmax, const float mask_filter_value) { - extern __shared__ float cached_data[]; // float[all_sequence_length] + extern __shared__ float cached_data[]; // float[total_sequence_length] using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmp_storage; @@ -271,37 +319,54 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int all_sequence_length, __shared__ float max_block; float max_thread_data = -CUDART_INF_F; - const int size_per_batch = gridDim.x * all_sequence_length; + const int size_per_batch = gridDim.x * total_sequence_length; + + // grid size is (sequence_length, num_heads, batch_size); total_sequence_length is partitioned by TPB. 
+ const int sequence_length = gridDim.x; + const int num_heads = gridDim.y; + const int batch_size = gridDim.z; + const int s = blockIdx.x; + const int n = blockIdx.y; + const int b = blockIdx.z; + + // input and output shape is (batch_size, num_heads, sequence_length, total_sequence_length) + int block = b * num_heads * sequence_length + n * sequence_length + s; + const int64_t offset = static_cast(block) * static_cast(total_sequence_length); + + const bool has_bias = (attn_bias != nullptr); + int64_t bias_offset = 0; + if (has_bias) { + // bias shape is (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length) + block = (broadcast_attn_bias_dim_0 ? 0 : (b * num_heads * sequence_length)) + + (broadcast_attn_bias_dim_1 ? 0 : (n * sequence_length)) + + s; + bias_offset = static_cast(block) * static_cast(total_sequence_length); + } - // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; - int base_index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; - for (int seq_idx = threadIdx.x; seq_idx < all_sequence_length; seq_idx += TPB) { + for (int i = threadIdx.x; i < total_sequence_length; i += TPB) { float thread_data = -CUDART_INF_F; - int index = base_index + seq_idx; - if (rel_pos_bias == nullptr) { + int index = offset + i; + if (attn_bias == nullptr) { thread_data = float(input[index]) * rsqrt_head_size; } else { - T rel_pos_bias_value = broadcast_rel_pos_bias ? rel_pos_bias[index % size_per_batch] : rel_pos_bias[index]; - thread_data = float(input[index] + rel_pos_bias_value) * rsqrt_head_size; + thread_data = (float(input[index]) + float(attn_bias[bias_offset + i])) * rsqrt_head_size; } - const int sequence_index = blockIdx.x % sequence_length; if (causal) { - int from_index = all_sequence_length - sequence_length + sequence_index; // offset in all sequence length. 
- if (seq_idx > from_index) { + int from_index = total_sequence_length - sequence_length + s; // offset in total sequence length. + if (i > from_index) { thread_data = -CUDART_INF_F; } } int mask_offset = 0; - const int batch_index = blockIdx.y; if (mask_dimension == 2) { - mask_offset = batch_index * all_sequence_length + seq_idx; + mask_offset = b * total_sequence_length + i; } else if (mask_dimension == 3) { - mask_offset = (batch_index * sequence_length + sequence_index) * all_sequence_length + seq_idx; + mask_offset = (b * sequence_length + s) * total_sequence_length + i; } else if (mask_dimension == 4) { - int from_index = all_sequence_length - sequence_length + sequence_index; - mask_offset = (batch_index * max_sequence_length + from_index) * max_sequence_length + seq_idx; + int from_index = total_sequence_length - sequence_length + s; + mask_offset = (b * max_sequence_length + from_index) * max_sequence_length + i; } if (nullptr == key_padding_mask) { @@ -318,7 +383,7 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int all_sequence_length, if (skip_softmax) { output[index] = T(thread_data); } - cached_data[seq_idx] = thread_data; + cached_data[i] = thread_data; max_thread_data = max(max_thread_data, thread_data); } @@ -326,7 +391,7 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int all_sequence_length, return; } - const float max = BlockReduce(tmp_storage).Reduce(max_thread_data, cub::Max(), all_sequence_length); + const float max = BlockReduce(tmp_storage).Reduce(max_thread_data, cub::Max(), total_sequence_length); // Store max value if (threadIdx.x == 0) { @@ -335,9 +400,9 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int all_sequence_length, __syncthreads(); float sum_thread_data_exp = 0.0f; - for (int seq_idx = threadIdx.x; seq_idx < all_sequence_length; seq_idx += TPB) { - auto ev = expf(cached_data[seq_idx] - max_block); - cached_data[seq_idx] = ev; + for (int i = threadIdx.x; i < total_sequence_length; i += TPB) { + auto ev = 
expf(cached_data[i] - max_block); + cached_data[i] = ev; sum_thread_data_exp += ev; } const auto sum = BlockReduce(tmp_storage).Reduce(sum_thread_data_exp, cub::Sum(), TPB); @@ -348,18 +413,19 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int all_sequence_length, } __syncthreads(); - for (int seq_idx = threadIdx.x; seq_idx < all_sequence_length; seq_idx += TPB) { - output[base_index + seq_idx] = T(cached_data[seq_idx] * sum_reverse_block); + for (int i = threadIdx.x; i < total_sequence_length; i += TPB) { + output[offset + i] = T(cached_data[i] * sum_reverse_block); } } +// This kernel is for causal or not, raw attention mask (2D, 3D or 4D), and total_sequence_length <= 1024. template -__device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, - const int sequence_length, +__device__ inline void SoftmaxWithRawMaskSmall(const int total_sequence_length, const int* attention_mask, // 2D, 3D or 4D attention mask const bool* key_padding_mask, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output, const bool causal, @@ -374,31 +440,49 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, __shared__ float sum_reverse_block; __shared__ float max_block; - // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; - int index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length + threadIdx.x; - const int size_per_batch = gridDim.x * all_sequence_length; + // grid size is (sequence_length, num_heads, batch_size); total_sequence_length is within one block size TPB. 
+ const int sequence_length = gridDim.x; + const int num_heads = gridDim.y; + const int batch_size = gridDim.z; + const int s = blockIdx.x; + const int n = blockIdx.y; + const int b = blockIdx.z; + + // input and output shape is (batch_size, num_heads, sequence_length, total_sequence_length) + int block = b * num_heads * sequence_length + n * sequence_length + s; + const int64_t offset = static_cast(block) * static_cast(total_sequence_length); + + const bool has_bias = (attn_bias != nullptr); + int64_t bias_offset = 0; + if (has_bias) { + // bias shape is (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length) + block = (broadcast_attn_bias_dim_0 ? 0 : (b * num_heads * sequence_length)) + + (broadcast_attn_bias_dim_1 ? 0 : (n * sequence_length)) + + s; + bias_offset = static_cast(block) * static_cast(total_sequence_length); + } + + int64_t index = offset + threadIdx.x; float thread_data = -CUDART_INF_F; - if (threadIdx.x < all_sequence_length) { + if (threadIdx.x < total_sequence_length) { thread_data = float(input[index]) * rsqrt_head_size; - const int sequence_index = blockIdx.x % sequence_length; if (causal) { - int from_index = all_sequence_length - sequence_length + sequence_index; // offset in all sequence length. + int from_index = total_sequence_length - sequence_length + s; // offset in total sequence length. 
if (threadIdx.x > from_index) { thread_data = -CUDART_INF_F; } } int mask_offset = 0; - const int batch_index = blockIdx.y; if (mask_dimension == 2) { - mask_offset = batch_index * all_sequence_length + threadIdx.x; + mask_offset = b * total_sequence_length + threadIdx.x; } else if (mask_dimension == 3) { - mask_offset = (batch_index * sequence_length + sequence_index) * all_sequence_length + threadIdx.x; + mask_offset = (b * sequence_length + s) * total_sequence_length + threadIdx.x; } else if (mask_dimension == 4) { - int from_index = all_sequence_length - sequence_length + sequence_index; - mask_offset = (batch_index * max_sequence_length + from_index) * max_sequence_length + threadIdx.x; + int from_index = total_sequence_length - sequence_length + s; + mask_offset = (b * max_sequence_length + from_index) * max_sequence_length + threadIdx.x; } if (nullptr == key_padding_mask) { @@ -412,20 +496,19 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, } } - if (rel_pos_bias != nullptr) { - float bias = broadcast_rel_pos_bias ? float(rel_pos_bias[index % size_per_batch]) : float(rel_pos_bias[index]); - thread_data += bias; + if (attn_bias != nullptr) { + thread_data += float(attn_bias[bias_offset + threadIdx.x]); } } if (skip_softmax) { - if (threadIdx.x < all_sequence_length) { + if (threadIdx.x < total_sequence_length) { output[index] = T(thread_data); } return; } - const float max = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max(), all_sequence_length); + const float max = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max(), total_sequence_length); // Store max value if (threadIdx.x == 0) { @@ -433,8 +516,8 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, } __syncthreads(); - float thread_data_exp = threadIdx.x < all_sequence_length ? 
expf(thread_data - max_block) : 0.0f; - const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), all_sequence_length); + float thread_data_exp = threadIdx.x < total_sequence_length ? expf(thread_data - max_block) : 0.0f; + const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), total_sequence_length); // Store value of 1.0/sum if (threadIdx.x == 0) { @@ -442,71 +525,74 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, } __syncthreads(); - if (threadIdx.x < all_sequence_length) { + if (threadIdx.x < total_sequence_length) { output[index] = T(thread_data_exp * sum_reverse_block); } } template -__global__ void SoftmaxKernelSmall(const int all_sequence_length, - const int sequence_length, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, +__global__ void SoftmaxKernelSmall(const int total_sequence_length, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output, bool causal) { - SoftmaxSmall(all_sequence_length, sequence_length, all_sequence_length, 0, - rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); + SoftmaxSmall(total_sequence_length, total_sequence_length, 0, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); } template -__global__ void SoftmaxKernel(const int all_sequence_length, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, +__global__ void SoftmaxKernel(const int total_sequence_length, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output) { - Softmax(all_sequence_length, all_sequence_length, 0, - rel_pos_bias, broadcast_rel_pos_bias, input, output); + Softmax(total_sequence_length, total_sequence_length, 0, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output); } template -Status ComputeSoftmax(cudaStream_t stream, const int 
all_sequence_length, const int sequence_length, - const int batch_size, const int num_heads, const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, T* input, T* output, bool causal) { - const dim3 grid(sequence_length * num_heads, batch_size, 1); - if (all_sequence_length <= 32) { +Status ComputeSoftmax(cudaStream_t stream, const int total_sequence_length, const int sequence_length, + const int batch_size, const int num_heads, const T* attn_bias, + const bool broadcast_attn_bias_dim_0, const bool broadcast_attn_bias_dim_1, + T* input, T* output, bool causal) { + const dim3 grid(sequence_length, num_heads, batch_size); + if (total_sequence_length <= 32) { const int blockSize = 32; SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); - } else if (all_sequence_length <= 64) { + total_sequence_length, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); + } else if (total_sequence_length <= 64) { const int blockSize = 64; SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); - } else if (all_sequence_length <= 128) { + total_sequence_length, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); + } else if (total_sequence_length <= 128) { const int blockSize = 128; SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); - } else if (all_sequence_length <= 256) { + total_sequence_length, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); + } else if (total_sequence_length <= 256) { const int blockSize = 256; SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); - } else if (all_sequence_length <= 512) { + total_sequence_length, attn_bias, broadcast_attn_bias_dim_0, 
broadcast_attn_bias_dim_1, input, output, causal); + } else if (total_sequence_length <= 512) { const int blockSize = 512; SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); - } else if (all_sequence_length <= 1024) { + total_sequence_length, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); + } else if (total_sequence_length <= 1024) { const int blockSize = 1024; SoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); + total_sequence_length, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); } else if (!causal) { const int blockSize = 1024; SoftmaxKernel<<>>( - all_sequence_length, rel_pos_bias, broadcast_rel_pos_bias, input, output); + total_sequence_length, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output); } else { const int blockSize = 256; - const int sh_bytes = sizeof(float) * all_sequence_length; + const int sh_bytes = sizeof(float) * total_sequence_length; SoftmaxLargeKernel<<>>( - all_sequence_length, sequence_length, all_sequence_length, 0, rel_pos_bias, broadcast_rel_pos_bias, + total_sequence_length, total_sequence_length, 0, attn_bias, + broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, true); } @@ -514,12 +600,12 @@ Status ComputeSoftmax(cudaStream_t stream, const int all_sequence_length, const } template -__global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, - const int sequence_length, +__global__ void MaskedSoftmaxKernelSmall(const int total_sequence_length, const int* mask_end, const int* mask_start, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output, bool causal) { @@ -527,27 +613,28 @@ __global__ void 
MaskedSoftmaxKernelSmall(const int all_sequence_length, __shared__ int end_position; if (threadIdx.x == 0) { - const int batch = blockIdx.y; + const int batch = blockIdx.z; start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; - end_position = min(all_sequence_length, mask_end[batch]); + end_position = min(total_sequence_length, mask_end[batch]); // Attend to no word has same effect as attend to all words. This is added to get parity with CPU result. if (start_position >= end_position) { start_position = 0; - end_position = all_sequence_length; + end_position = total_sequence_length; } } __syncthreads(); - SoftmaxSmall(all_sequence_length, sequence_length, end_position, start_position, - rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); + SoftmaxSmall(total_sequence_length, end_position, start_position, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); } template -__device__ inline void SoftmaxSmallPacked(const int sequence_length, +__device__ inline void SoftmaxSmallPacked(const int total_sequence_length, const int end, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output) { using BlockReduce = cub::BlockReduce; @@ -556,23 +643,34 @@ __device__ inline void SoftmaxSmallPacked(const int sequence_length, __shared__ float sum_reverse_block; __shared__ float max_block; - // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; - const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * sequence_length; - const int index = offset + threadIdx.x; + + // grid size is (sequence_length, num_heads, batch_size); total_sequence_length is within TPB. 
+ const int sequence_length = gridDim.x; + const int num_heads = gridDim.y; + const int batch_size = gridDim.z; + const int s = blockIdx.x; + const int n = blockIdx.y; + const int b = blockIdx.z; + + // input and output shape is (batch_size, num_heads, sequence_length, total_sequence_length) + int block = b * num_heads * sequence_length + n * sequence_length + s; + const int64_t offset = static_cast(block) * static_cast(total_sequence_length); + + const bool has_bias = (attn_bias != nullptr); + int64_t bias_offset = 0; + if (has_bias) { + // bias shape is (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length) + block = (broadcast_attn_bias_dim_0 ? 0 : (b * num_heads * sequence_length)) + + (broadcast_attn_bias_dim_1 ? 0 : (n * sequence_length)) + + s; + bias_offset = static_cast(block) * static_cast(total_sequence_length); + } + + int64_t index = offset + threadIdx.x; bool is_valid = threadIdx.x < end; - // e^x is represented as infinity if x is large enough, like 100.f. - // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. - // a math transform as below is leveraged to get a stable softmax: - // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) - const bool no_rpb = (rel_pos_bias == nullptr); - const int size_per_batch = gridDim.x * sequence_length; - float input_data = no_rpb - ? float(input[index]) - : float(input[index] + (broadcast_rel_pos_bias - ? rel_pos_bias[index % size_per_batch] - : rel_pos_bias[index])); + float input_data = has_bias ? float(input[index]) + float(attn_bias[bias_offset + threadIdx.x]) : float(input[index]); float thread_data_max = is_valid ? 
input_data : float(-CUDART_INF_F); const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end); @@ -596,7 +694,7 @@ __device__ inline void SoftmaxSmallPacked(const int sequence_length, } __syncthreads(); - // threadIdx.x might be larger than all_sequence_length due to alignment to 32x. + // threadIdx.x might be larger than total_sequence_length due to alignment to 32x. if (threadIdx.x < sequence_length) { output[index] = T(thread_data_exp * sum_reverse_block); } @@ -604,73 +702,79 @@ __device__ inline void SoftmaxSmallPacked(const int sequence_length, template __global__ void SoftmaxKernelSmallWithCumSeqLen(const T* input, - const T* rel_pos_bias, const bool broadcast_rel_pos_bias, - const int* cum_seq_length, const int sequence_length, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, + const int* cum_seq_length, + const int total_sequence_length, T* output) { __shared__ int end_position; if (threadIdx.x == 0) { - const int batch = blockIdx.y; + const int batch = blockIdx.z; end_position = cum_seq_length[batch + 1] - cum_seq_length[batch]; } __syncthreads(); - SoftmaxSmallPacked(sequence_length, end_position, - rel_pos_bias, broadcast_rel_pos_bias, - input, output); + SoftmaxSmallPacked(total_sequence_length, end_position, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output); } template __global__ void SoftmaxKernelWithCumSeqLen(const T* input, - const T* rel_pos_bias, const bool broadcast_rel_pos_bias, - const int* cum_seq_length, const int sequence_length, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, + const int* cum_seq_length, + const int total_sequence_length, T* output) { __shared__ int end_position; if (threadIdx.x == 0) { - const int batch = blockIdx.y; + const int batch = blockIdx.z; end_position = cum_seq_length[batch + 1] - cum_seq_length[batch]; } __syncthreads(); - Softmax(sequence_length, 
end_position, 0 /*start_position*/, - rel_pos_bias, broadcast_rel_pos_bias, input, output); + Softmax(total_sequence_length, end_position, 0 /*start_position*/, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output); } template -__global__ void MaskedSoftmaxKernel(const int all_sequence_length, +__global__ void MaskedSoftmaxKernel(const int total_sequence_length, const int* mask_end, const int* mask_start, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output) { __shared__ int start_position; __shared__ int end_position; if (threadIdx.x == 0) { - const int batch = blockIdx.y; + const int batch = blockIdx.z; start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; - end_position = min(all_sequence_length, mask_end[batch]); + end_position = min(total_sequence_length, mask_end[batch]); // Attend to no word has same effect as attend to all words. This is added to get parity with CPU result. 
if (start_position >= end_position) { start_position = 0; - end_position = all_sequence_length; + end_position = total_sequence_length; } } __syncthreads(); - Softmax(all_sequence_length, end_position, start_position, - rel_pos_bias, broadcast_rel_pos_bias, input, output); + Softmax(total_sequence_length, end_position, start_position, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output); } template -__global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length, - const int sequence_length, +__global__ void SoftmaxWithRawMaskSmallKernel(const int total_sequence_length, const int* attention_mask, const bool* key_padding_mask, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output, const bool causal, @@ -680,8 +784,8 @@ __global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length, const bool skip_softmax, const float mask_filter_value) { SoftmaxWithRawMaskSmall( - all_sequence_length, sequence_length, - attention_mask, key_padding_mask, rel_pos_bias, broadcast_rel_pos_bias, input, output, + total_sequence_length, attention_mask, key_padding_mask, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal, rsqrt_head_size, mask_dimension, max_sequence_length, skip_softmax, mask_filter_value); } @@ -689,50 +793,52 @@ __global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length, template Status ComputeSoftmaxWithCumSeqLength( const T* input, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const int32_t* cum_seq_length, const int batch_size, const int sequence_length, + const int total_sequence_length, const int num_heads, T* output, cudaStream_t stream) { - const dim3 grid(sequence_length * num_heads, batch_size, 1); + 
const dim3 grid(sequence_length, num_heads, batch_size); if (sequence_length <= 32) { const int blockSize = 32; SoftmaxKernelSmallWithCumSeqLen - <<>>(input, rel_pos_bias, broadcast_rel_pos_bias, - cum_seq_length, sequence_length, output); + <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + cum_seq_length, total_sequence_length, output); } else if (sequence_length <= 64) { const int blockSize = 64; SoftmaxKernelSmallWithCumSeqLen - <<>>(input, rel_pos_bias, broadcast_rel_pos_bias, - cum_seq_length, sequence_length, output); + <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + cum_seq_length, total_sequence_length, output); } else if (sequence_length <= 128) { const int blockSize = 128; SoftmaxKernelSmallWithCumSeqLen - <<>>(input, rel_pos_bias, broadcast_rel_pos_bias, - cum_seq_length, sequence_length, output); + <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + cum_seq_length, total_sequence_length, output); } else if (sequence_length <= 256) { const int blockSize = 256; SoftmaxKernelSmallWithCumSeqLen - <<>>(input, rel_pos_bias, broadcast_rel_pos_bias, - cum_seq_length, sequence_length, output); + <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + cum_seq_length, total_sequence_length, output); } else if (sequence_length <= 512) { const int blockSize = 512; SoftmaxKernelSmallWithCumSeqLen - <<>>(input, rel_pos_bias, broadcast_rel_pos_bias, - cum_seq_length, sequence_length, output); + <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + cum_seq_length, total_sequence_length, output); } else if (sequence_length <= 1024) { const int blockSize = 1024; SoftmaxKernelSmallWithCumSeqLen - <<>>(input, rel_pos_bias, broadcast_rel_pos_bias, - cum_seq_length, sequence_length, output); + <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + cum_seq_length, total_sequence_length, output); } else { 
SoftmaxKernelWithCumSeqLen - <<>>(input, rel_pos_bias, broadcast_rel_pos_bias, - cum_seq_length, sequence_length, output); + <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + cum_seq_length, total_sequence_length, output); } return CUDA_CALL(cudaGetLastError()); @@ -740,54 +846,62 @@ Status ComputeSoftmaxWithCumSeqLength( template Status ComputeSoftmaxWithMask1D(cudaStream_t stream, - const int all_sequence_length, + const int total_sequence_length, const int sequence_length, const int batch_size, const int num_heads, const int* mask_index, const int* mask_start, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output, const bool causal) { - const dim3 grid(sequence_length * num_heads, batch_size, 1); + const dim3 grid(sequence_length, num_heads, batch_size); - if (all_sequence_length <= 32) { + if (total_sequence_length <= 32) { const int blockSize = 32; MaskedSoftmaxKernelSmall - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, - rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); - } else if (all_sequence_length <= 64) { + <<>>(total_sequence_length, mask_index, mask_start, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + input, output, causal); + } else if (total_sequence_length <= 64) { const int blockSize = 64; MaskedSoftmaxKernelSmall - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, - rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); - } else if (all_sequence_length <= 128) { + <<>>(total_sequence_length, mask_index, mask_start, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + input, output, causal); + } else if (total_sequence_length <= 128) { const int blockSize = 128; MaskedSoftmaxKernelSmall - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, - rel_pos_bias, 
broadcast_rel_pos_bias, input, output, causal); - } else if (all_sequence_length <= 256) { + <<>>(total_sequence_length, mask_index, mask_start, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + input, output, causal); + } else if (total_sequence_length <= 256) { const int blockSize = 256; MaskedSoftmaxKernelSmall - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, - rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); - } else if (all_sequence_length <= 512) { + <<>>(total_sequence_length, mask_index, mask_start, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + input, output, causal); + } else if (total_sequence_length <= 512) { const int blockSize = 512; MaskedSoftmaxKernelSmall - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, - rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); - } else if (all_sequence_length <= 1024) { + <<>>(total_sequence_length, mask_index, mask_start, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + input, output, causal); + } else if (total_sequence_length <= 1024) { const int blockSize = 1024; MaskedSoftmaxKernelSmall - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, - rel_pos_bias, broadcast_rel_pos_bias, input, output, causal); + <<>>(total_sequence_length, mask_index, mask_start, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + input, output, causal); } else if (!causal) { const int blockSize = 1024; MaskedSoftmaxKernel - <<>>(all_sequence_length, mask_index, mask_start, - rel_pos_bias, broadcast_rel_pos_bias, input, output); + <<>>(total_sequence_length, mask_index, mask_start, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + input, output); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention CUDA operator does not support total sequence length > 1024."); } @@ -797,14 +911,15 @@ Status ComputeSoftmaxWithMask1D(cudaStream_t stream, template Status 
ComputeSoftmaxWithRawMask(Stream* ort_stream, - const int all_sequence_length, + const int total_sequence_length, const int sequence_length, const int batch_size, const int num_heads, const int* attention_mask, const bool* key_padding_mask, - const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const T* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output, const bool causal, @@ -815,58 +930,58 @@ Status ComputeSoftmaxWithRawMask(Stream* ort_stream, T* persistent_softmax_workspace, const float mask_filter_value) { auto stream = static_cast(ort_stream->GetHandle()); - const dim3 grid(sequence_length * num_heads, batch_size, 1); + const dim3 grid(sequence_length, num_heads, batch_size); T* out = use_persistent_softmax ? persistent_softmax_workspace : output; - if (all_sequence_length <= 32) { + if (total_sequence_length <= 32) { const int blockSize = 32; SoftmaxWithRawMaskSmallKernel - <<>>(all_sequence_length, sequence_length, - attention_mask, key_padding_mask, rel_pos_bias, broadcast_rel_pos_bias, input, + <<>>(total_sequence_length, attention_mask, key_padding_mask, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax, mask_filter_value); - } else if (all_sequence_length <= 64) { + } else if (total_sequence_length <= 64) { const int blockSize = 64; SoftmaxWithRawMaskSmallKernel - <<>>(all_sequence_length, sequence_length, - attention_mask, key_padding_mask, rel_pos_bias, broadcast_rel_pos_bias, input, + <<>>(total_sequence_length, attention_mask, key_padding_mask, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax, mask_filter_value); - } else if (all_sequence_length <= 128) { + } else if (total_sequence_length <= 128) { const int blockSize = 128; 
SoftmaxWithRawMaskSmallKernel - <<>>(all_sequence_length, sequence_length, - attention_mask, key_padding_mask, rel_pos_bias, broadcast_rel_pos_bias, input, + <<>>(total_sequence_length, attention_mask, key_padding_mask, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax, mask_filter_value); - } else if (all_sequence_length <= 256) { + } else if (total_sequence_length <= 256) { const int blockSize = 256; SoftmaxWithRawMaskSmallKernel - <<>>(all_sequence_length, sequence_length, - attention_mask, key_padding_mask, rel_pos_bias, broadcast_rel_pos_bias, input, + <<>>(total_sequence_length, attention_mask, key_padding_mask, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax, mask_filter_value); - } else if (all_sequence_length <= 512) { + } else if (total_sequence_length <= 512) { const int blockSize = 512; SoftmaxWithRawMaskSmallKernel - <<>>(all_sequence_length, sequence_length, - attention_mask, key_padding_mask, rel_pos_bias, broadcast_rel_pos_bias, input, + <<>>(total_sequence_length, attention_mask, key_padding_mask, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax, mask_filter_value); - } else if (all_sequence_length <= 1024) { + } else if (total_sequence_length <= 1024) { const int blockSize = 1024; SoftmaxWithRawMaskSmallKernel - <<>>(all_sequence_length, sequence_length, - attention_mask, key_padding_mask, rel_pos_bias, broadcast_rel_pos_bias, input, + <<>>(total_sequence_length, attention_mask, key_padding_mask, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax, mask_filter_value); } else { const int blockSize = 
256; - const int sh_bytes = sizeof(float) * all_sequence_length; + const int sh_bytes = sizeof(float) * total_sequence_length; SoftmaxWithRawMaskLargeKernel <<>>( - all_sequence_length, sequence_length, - attention_mask, key_padding_mask, rel_pos_bias, broadcast_rel_pos_bias, input, + total_sequence_length, attention_mask, key_padding_mask, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax, mask_filter_value); } @@ -876,8 +991,8 @@ Status ComputeSoftmaxWithRawMask(Stream* ort_stream, ort_stream, output, persistent_softmax_workspace, - all_sequence_length, - all_sequence_length, + total_sequence_length, + total_sequence_length, batch_size * num_heads * sequence_length); } @@ -886,70 +1001,79 @@ Status ComputeSoftmaxWithRawMask(Stream* ort_stream, // Template Instantiation template Status ComputeSoftmax( - cudaStream_t stream, const int all_sequence_length, const int sequence_length, - const int batch_size, const int num_heads, const float* rel_pos_bias, - const bool broadcast_rel_pos_bias, float* input, float* output, bool causal); + cudaStream_t stream, const int total_sequence_length, const int sequence_length, + const int batch_size, const int num_heads, const float* attn_bias, + const bool broadcast_attn_bias_dim_0, const bool broadcast_attn_bias_dim_1, + float* input, float* output, bool causal); template Status ComputeSoftmax( - cudaStream_t stream, const int all_sequence_length, const int sequence_length, - const int batch_size, const int num_heads, const half* rel_pos_bias, - const bool broadcast_rel_pos_bias, half* input, half* output, bool causal); + cudaStream_t stream, const int total_sequence_length, const int sequence_length, + const int batch_size, const int num_heads, const half* attn_bias, + const bool broadcast_attn_bias_dim_0, const bool broadcast_attn_bias_dim_1, + half* input, half* output, bool causal); template Status 
ComputeSoftmaxWithCumSeqLength( const float* input, - const float* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const float* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const int32_t* cum_seq_length, const int batch_size, const int sequence_length, + const int total_sequence_length, const int num_heads, float* output, cudaStream_t stream); template Status ComputeSoftmaxWithCumSeqLength( const half* input, - const half* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const half* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const int32_t* cum_seq_length, const int batch_size, const int sequence_length, + const int total_sequence_length, const int num_heads, half* output, cudaStream_t stream); template Status ComputeSoftmaxWithMask1D(cudaStream_t stream, - const int all_sequence_length, + const int total_sequence_length, const int sequence_length, const int batch_size, const int num_heads, const int* mask_index, const int* mask_start, - const float* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const float* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const float* input, float* output, const bool causal); template Status ComputeSoftmaxWithMask1D(cudaStream_t stream, - const int all_sequence_length, + const int total_sequence_length, const int sequence_length, const int batch_size, const int num_heads, const int* mask_index, const int* mask_start, - const half* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const half* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const half* input, half* output, const bool causal); template Status ComputeSoftmaxWithRawMask(Stream* ort_stream, - const int all_sequence_length, + const int total_sequence_length, const int sequence_length, const int batch_size, const int num_heads, const int* attention_mask, const bool* 
key_padding_mask, - const float* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const float* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const float* input, float* output, const bool causal, @@ -961,14 +1085,15 @@ template Status ComputeSoftmaxWithRawMask(Stream* ort_stream, const float mask_filter_value); template Status ComputeSoftmaxWithRawMask(Stream* ort_stream, - const int all_sequence_length, + const int total_sequence_length, const int sequence_length, const int batch_size, const int num_heads, const int* attention_mask, const bool* key_padding_mask, - const half* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const half* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const half* input, half* output, const bool causal, diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h index 46d2423fa7009..f7fab268b4607 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h @@ -10,16 +10,19 @@ namespace attention_softmax_cuda { template Status ComputeSoftmax(cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, T* input, T* output, bool causal); + const bool broadcast_attn_bias_dim_0, const bool broadcast_attn_bias_dim_1, + T* input, T* output, bool causal); template Status ComputeSoftmaxWithCumSeqLength( const T* input, const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const int32_t* cum_seq_length, const int batch_size, const int sequence_length, + const int total_sequence_length, const int num_heads, T* output, cudaStream_t stream); @@ -32,7 +35,8 @@ Status ComputeSoftmaxWithMask1D(cudaStream_t stream, const int* 
mask_index, const int* mask_start, const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output, const bool causal); @@ -46,7 +50,8 @@ Status ComputeSoftmaxWithRawMask(Stream* ort_stream, const int* attention_mask, const bool* key_padding_mask, const T* rel_pos_bias, - const bool broadcast_rel_pos_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, const T* input, T* output, const bool causal, diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index a5de20e44be1a..5029abe7e11e6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -184,35 +184,45 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { p.q_strideH = params.qk_head_size; p.k_strideH = params.qk_head_size; p.v_strideH = params.v_head_size; - p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; p.q_strideM = params.num_heads * params.qk_head_size; p.k_strideM = params.num_heads * params.qk_head_size; p.v_strideM = params.num_heads * params.v_head_size; p.o_strideM = params.num_heads * params.v_head_size; - p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; p.q_strideB = static_cast(p.q_strideM) * params.sequence_length; p.k_strideB = static_cast(p.k_strideM) * params.max_sequence_length; p.v_strideB = static_cast(p.v_strideM) * params.max_sequence_length; - p.bias_strideB = params.is_attn_bias_batched ? 
static_cast(p.bias_strideH) * params.num_heads : 0; } else { // Input K, V format is BxNxSxH, Input Q is BxSxNxH, output is BxSxNxH p.q_strideH = params.qk_head_size; p.k_strideH = params.max_sequence_length * params.qk_head_size; p.v_strideH = params.max_sequence_length * params.v_head_size; - p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; p.q_strideM = params.num_heads * params.qk_head_size; p.k_strideM = params.qk_head_size; p.v_strideM = params.v_head_size; p.o_strideM = params.num_heads * params.v_head_size; - p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; p.q_strideB = params.num_heads * params.qk_head_size * params.sequence_length; p.k_strideB = params.num_heads * params.qk_head_size * params.max_sequence_length; p.v_strideB = params.num_heads * params.v_head_size * params.max_sequence_length; - p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + } + + if (params.attn_bias != nullptr) { + auto& bias_dims = params.attn_bias_dims; + ORT_ENFORCE(bias_dims.size() == 4 && + (bias_dims[0] == 1 || bias_dims[0] == params.batch_size) && + (bias_dims[1] == 1 || bias_dims[1] == params.num_heads) && + bias_dims[2] == params.sequence_length && + bias_dims[3] == params.kv_sequence_length); + p.bias_strideH = p.num_queries * p.num_keys; + p.bias_strideM = p.num_keys; + p.bias_strideB = (bias_dims[0] == 1) ? 
0 : (bias_dims[1] * p.num_queries * p.num_keys); + } else { + p.bias_strideH = 0; + p.bias_strideM = 0; + p.bias_strideB = 0; } } diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h index 08a562a12b844..918eec15f45b1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h @@ -25,8 +25,6 @@ struct MemoryEfficientAttentionParams { int32_t qk_head_size; int32_t v_head_size; bool causal; - // The default shape of attn_bias is [1, N, S, S*]. Sometimes we need to use [B, N, S, S*] in custom models. - bool is_attn_bias_batched; float scale; @@ -37,9 +35,11 @@ struct MemoryEfficientAttentionParams { const void* query; // [B, S, N, H] const void* key; // [B, L, N, H], where L is kv_sequence_length const void* value; // [B, L, N, H_v] - const void* attn_bias; // [N, S, S*] or null - void* output; // [B, S, N, H_v] - void* workspace; // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise + const void* attn_bias; // [B or 1, N or 1, S, L] or null + gsl::span attn_bias_dims; + + void* output; // [B, S, N, H_v] + void* workspace; // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise cudaStream_t stream; static bool need_workspace(size_t v_head_size, bool is_float) { diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu index c0b1996789183..65d2c113576f6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu @@ -140,21 +140,25 @@ Status DecoderQkvToContext( } constexpr bool is_unidirectional = false; - const T* add_before_softmax = nullptr; + const T* attention_bias = nullptr; + constexpr bool broadcast_attn_bias_dim_0 = false; + constexpr bool 
broadcast_attn_bias_dim_1 = false; + if (has_key_padding_mask) { constexpr int mask_dimension = 2; constexpr int max_sequence_length = 0; ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask( ort_stream, kv_sequence_length, sequence_length, batch_size, - num_heads, nullptr, key_padding_mask, add_before_softmax, - false /*broadcast rpb*/, scratch1, scratch2, is_unidirectional, + num_heads, nullptr, key_padding_mask, + attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + scratch1, scratch2, is_unidirectional, 1.0f, mask_dimension, max_sequence_length, false, nullptr, mask_filter_value)); } else { ORT_RETURN_IF_ERROR(ComputeSoftmax( stream, kv_sequence_length, sequence_length, batch_size, num_heads, - add_before_softmax, false /*broadcast rpb*/, scratch1, scratch2, - is_unidirectional)); + attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + scratch1, scratch2, is_unidirectional)); } // compute P*V (as V*P), and store in scratch3: BxNxSxH diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 3099b52cce13e..b694de48d2961 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -836,7 +836,6 @@ Status EfficientAttention( p.key = key; p.value = value; p.attn_bias = nullptr; - p.is_attn_bias_batched = false; p.is_kv_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; p.output = data.output; p.workspace = MemoryEfficientAttentionParams::need_workspace(p.v_head_size, sizeof(T) == sizeof(float)) diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 2835192abd298..58e41345431e1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -74,7 +74,7 @@ Status 
MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* value = context->Input(2); const Tensor* bias = context->Input(3); const Tensor* key_padding_mask = context->Input(4); - const Tensor* relative_position_bias = context->Input(5); + const Tensor* attention_bias = context->Input(5); const Tensor* past_key = context->Input(6); const Tensor* past_value = context->Input(7); @@ -87,7 +87,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { value, bias, key_padding_mask, - relative_position_bias, + attention_bias, past_key, past_value, nullptr, // past_seq_len @@ -150,7 +150,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { #if USE_FLASH_ATTENTION bool use_flash_attention = !disable_flash_attention_ && - nullptr == relative_position_bias && + nullptr == attention_bias && nullptr == key_padding_mask && parameters.head_size == parameters.v_head_size && onnxruntime::flash::is_supported(device_prop, @@ -188,7 +188,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { !use_flash_attention && !disable_fused_cross_attention_ && nullptr == key_padding_mask && - nullptr == relative_position_bias && + nullptr == attention_bias && nullptr == past_key && nullptr == present_key && (parameters.qkv_format == Q_K_V_BSNH || (parameters.qkv_format == Q_KV_BSNH_BSN2H && bias == nullptr)) && parameters.hidden_size == parameters.v_hidden_size && @@ -212,7 +212,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { !use_flash_attention && !disable_fused_self_attention_ && fused_cross_attention_kernel == nullptr && - nullptr == relative_position_bias && + nullptr == attention_bias && (parameters.qkv_format == Q_K_V_BSNH || parameters.qkv_format == QKV_BSN3H) && nullptr == past_key && nullptr == present_key && is_mask_none_or_1d_k_len && @@ -243,16 +243,14 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { 
parameters.sequence_length >= length_threshold || parameters.kv_sequence_length >= length_threshold; - // Check whether the relative position bias alignment is good for memory efficient attention. - bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0; - bool use_memory_efficient_attention = !use_flash_attention && fused_runner == nullptr && fused_cross_attention_kernel == nullptr && !disable_memory_efficient_attention_ && is_long_sequence && - (relative_position_bias == nullptr || is_good_for_rpb) && + // Check whether the attention bias alignment is good for memory efficient attention. + (attention_bias == nullptr || parameters.sequence_length % (4 * sizeof(T)) == 0) && (nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) && has_memory_efficient_attention(sm, std::is_same::value, parameters.head_size, parameters.v_head_size); @@ -270,7 +268,10 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span() : key_padding_mask->Shape().GetDims(); data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast(past_key->Data()); data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast(past_value->Data()); - data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast(relative_position_bias->Data()); + if (nullptr != attention_bias) { + data.attention_bias = reinterpret_cast(attention_bias->Data()); + data.attention_bias_dims = attention_bias->Shape().GetDims(); + } data.output = reinterpret_cast(output->MutableData()); data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast(present_key->MutableData()); data.present_value = (nullptr == present_value) ? 
nullptr : reinterpret_cast(present_value->MutableData()); diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc index d1c6993d48e62..2a2df723e4f58 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc @@ -9,6 +9,7 @@ #include "contrib_ops/cuda/bert/packed_attention_impl.h" #include "contrib_ops/cuda/bert/bert_padding.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" +#include "contrib_ops/cpu/bert/multihead_attention_helper.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -46,7 +47,7 @@ MHARunner* TrtFusedAttention::GetFusedRunner(const cudaDeviceProp& device_pro MHARunner* fused_runner = nullptr; bool use_fused_runner = !disable_fused_runner_ && - !parameters.has_relative_position_bias && + parameters.attention_bias_dims.empty() && parameters.hidden_size == parameters.v_hidden_size; if (!use_fused_runner) { @@ -104,7 +105,7 @@ Status PackedAttention::CheckInputs(const TensorShape& input_shape, const TensorShape& bias_shape, const TensorShape& token_offset_shape, const TensorShape& cu_seq_len_shape, - const Tensor* relative_position_bias, + const Tensor* attention_bias, PackedAttentionParameters& parameters) const { // Abbreviation and Meanings: // T: token_count @@ -123,7 +124,7 @@ Status PackedAttention::CheckInputs(const TensorShape& input_shape, // bias (Q/K/V) : (D + D + D_v) // token_offset : (B, S) // cu_seq_len_shape : (B + 1) - // relative_position_bias : (B, N, S, S), (1, N, S, S) or NULL + // attention_bias : (B, N, S, S), (1, N, S, S) or NULL const auto& input_dims = input_shape.GetDims(); if (input_dims.size() != 2) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -204,43 +205,13 @@ Status PackedAttention::CheckInputs(const TensorShape& input_shape, v_hidden_size, "bias_dims[0]=", bias_dims[0]); } - bool broadcast_res_pos_bias = false; - if 
(relative_position_bias != nullptr) { - const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); - - if (relative_position_bias_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' is expected to have 4 dimensions, got ", - relative_position_bias_dims.size()); - } - - if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 0 should be same as batch_size or 1, got ", - relative_position_bias_dims[0]); - } - if (relative_position_bias_dims[0] == 1) { - broadcast_res_pos_bias = true; - } - - if (relative_position_bias_dims[1] != num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ", - relative_position_bias_dims[1]); - } - - if (relative_position_bias_dims[2] != sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ", - relative_position_bias_dims[2]); - } - - if (relative_position_bias_dims[3] != sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 3 should be same as sequence_length, got ", - relative_position_bias_dims[3]); - } + gsl::span attention_bias_dims; + if (attention_bias != nullptr) { + attention_bias_dims = attention_bias->Shape().GetDims(); + ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckAttentionBias( + attention_bias_dims, batch_size, num_heads, sequence_length, sequence_length)); } + parameters.attention_bias_dims = attention_bias_dims; parameters.batch_size = static_cast(batch_size); parameters.sequence_length = static_cast(sequence_length); @@ -252,8 +223,6 @@ Status PackedAttention::CheckInputs(const TensorShape& input_shape, parameters.num_heads = 
num_heads; parameters.scale = this->GetScale(); parameters.token_count = static_cast(token_count); - parameters.has_relative_position_bias = nullptr != relative_position_bias; - parameters.broadcast_res_pos_bias = broadcast_res_pos_bias; return Status::OK(); } @@ -265,7 +234,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* bias = context->Input(2); const Tensor* token_offset = context->Input(3); const Tensor* cumulative_sequence_length = context->Input(4); - const Tensor* relative_position_bias = context->Input(5); + const Tensor* attention_bias = context->Input(5); PackedAttentionParameters parameters; parameters.use_tf32 = this->UseTF32(); @@ -274,7 +243,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { bias->Shape(), token_offset->Shape(), cumulative_sequence_length->Shape(), - relative_position_bias, + attention_bias, parameters)); TensorShapeVector output_shape{parameters.token_count, parameters.v_hidden_size}; @@ -287,9 +256,8 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { #if USE_MEMORY_EFFICIENT_ATTENTION if (nullptr == fused_runner) { int sm = device_prop.major * 10 + device_prop.minor; - bool is_good_for_rpb = !parameters.has_relative_position_bias || parameters.sequence_length % (4 * sizeof(T)) == 0; use_memory_efficient_attention = - is_good_for_rpb && + (attention_bias == nullptr || parameters.sequence_length % (4 * sizeof(T)) == 0) && sizeof(T) == 2 && // only enable for fp16 has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size); } @@ -346,7 +314,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { PackedAttentionData data; data.gemm_buffer = reinterpret_cast(gemm_buffer.get()); data.bias = reinterpret_cast(bias->Data()); - data.relative_position_bias = (nullptr == relative_position_bias) ? 
nullptr : reinterpret_cast(relative_position_bias->Data()); + data.attention_bias = (nullptr == attention_bias) ? nullptr : reinterpret_cast(attention_bias->Data()); data.workspace = reinterpret_cast(work_space.get()); data.token_offset = token_offset->Data(); data.cumulative_sequence_length = cumulative_sequence_length->Data(); diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index 2521cd49b5482..890413b82d23f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -523,8 +523,8 @@ Status FusedScaledDotProductAttentionCutlass( p.query = query; p.key = key; p.value = value; - p.attn_bias = data.relative_position_bias; - p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; + p.attn_bias = data.attention_bias; + p.attn_bias_dims = parameters.attention_bias_dims; p.output = data.output; p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? 
accum_workspace : nullptr; p.stream = stream; @@ -603,14 +603,19 @@ Status UnfusedScaledDotProductAttention( sequence_length); T* attention_score = scaled_qk + (bytes / element_size); + bool broadcast_attn_bias_dim_0 = parameters.attention_bias_dims.size() > 0 && parameters.attention_bias_dims[0] == 1; + bool broadcast_attn_bias_dim_1 = parameters.attention_bias_dims.size() > 1 && parameters.attention_bias_dims[1] == 1; + // Apply softmax and store result R to attention_score: BxNxSxS ORT_RETURN_IF_ERROR(ComputeSoftmaxWithCumSeqLength( scaled_qk, - data.relative_position_bias, - parameters.broadcast_res_pos_bias, + data.attention_bias, + broadcast_attn_bias_dim_0, + broadcast_attn_bias_dim_1, data.cumulative_sequence_length, batch_size, sequence_length, + sequence_length, // total sequence length num_heads, attention_score, stream)); diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.h index 629ca59c73f16..1126c8a046da9 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.h @@ -33,7 +33,7 @@ template struct PackedAttentionData { T* gemm_buffer; const T* bias; - const T* relative_position_bias; + const T* attention_bias; const int32_t* token_offset; const int32_t* cumulative_sequence_length; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc index 53e96fc732a33..f9714e00c493f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc @@ -10,6 +10,7 @@ #include "contrib_ops/cuda/bert/bert_padding.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" +#include "contrib_ops/cpu/bert/multihead_attention_helper.h" using namespace onnxruntime::cuda; using 
namespace ::onnxruntime::common; @@ -54,7 +55,7 @@ Status PackedMultiHeadAttention::CheckInputs(const TensorShape& query_shape, const Tensor* bias, const TensorShape& token_offset_shape, const TensorShape& cu_seq_len_shape, - const Tensor* relative_position_bias, + const Tensor* attention_bias, PackedAttentionParameters& parameters) const { // Shapes of inputs and output: // When Q, K and V are not packed: @@ -67,7 +68,7 @@ Status PackedMultiHeadAttention::CheckInputs(const TensorShape& query_shape, // Input 'value': None // Input 'token_offset': (batch_size, sequence_length) // Input 'cumulative_sequence_length': (batch_size + 1) - // Input 'relative_position_bias': (batch_size or 1, num_heads, sequence_length, sequence_length) or None + // Input 'attention_bias': (batch_size or 1, num_heads, sequence_length, sequence_length) or None // Output 'output': (token_count, v_hidden_size) const auto& query_dims = query_shape.GetDims(); @@ -147,45 +148,15 @@ Status PackedMultiHeadAttention::CheckInputs(const TensorShape& query_shape, "Input 'cumulative_sequence_length' should have 1 dimension with size equal to batch_size + 1"); } - // TODO(tianleiwu): move relative position bias shape checker to a helper function. It is shared by multiple ops. 
const int num_heads = this->GetNumHeads(); - bool broadcast_res_pos_bias = false; - if (relative_position_bias != nullptr) { - const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); - if (relative_position_bias_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' is expected to have 4 dimensions, got ", - relative_position_bias_dims.size()); - } - - if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 0 should be same as batch_size or 1, got ", - relative_position_bias_dims[0]); - } - if (relative_position_bias_dims[0] == 1 && 1 != batch_size) { - broadcast_res_pos_bias = true; - } - - if (relative_position_bias_dims[1] != num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ", - relative_position_bias_dims[1]); - } - - if (relative_position_bias_dims[2] != sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ", - relative_position_bias_dims[2]); - } - - if (relative_position_bias_dims[3] != sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 3 should be same as sequence_length, got ", - relative_position_bias_dims[3]); - } + gsl::span attention_bias_dims; + if (attention_bias != nullptr) { + attention_bias_dims = attention_bias->Shape().GetDims(); + ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckAttentionBias( + attention_bias_dims, batch_size, num_heads, sequence_length, sequence_length)); } + parameters.attention_bias_dims = attention_bias_dims; parameters.batch_size = static_cast(batch_size); parameters.sequence_length = static_cast(sequence_length); @@ -197,8 
+168,6 @@ Status PackedMultiHeadAttention::CheckInputs(const TensorShape& query_shape, parameters.num_heads = num_heads; parameters.scale = this->GetScale(); parameters.token_count = static_cast(token_count); - parameters.has_relative_position_bias = (nullptr != relative_position_bias); - parameters.broadcast_res_pos_bias = broadcast_res_pos_bias; return Status::OK(); } @@ -211,7 +180,7 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co const Tensor* bias = context->Input(3); const Tensor* token_offset = context->Input(4); const Tensor* cumulative_sequence_length = context->Input(5); - const Tensor* relative_position_bias = context->Input(6); + const Tensor* attention_bias = context->Input(6); PackedAttentionParameters parameters; parameters.use_tf32 = this->UseTF32(); @@ -221,7 +190,7 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co bias, token_offset->Shape(), cumulative_sequence_length->Shape(), - relative_position_bias, + attention_bias, parameters)); TensorShapeVector output_shape{parameters.token_count, parameters.v_hidden_size}; @@ -232,7 +201,7 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co bool use_flash_attention = false; #if USE_FLASH_ATTENTION if (!disable_flash_attention_) { - use_flash_attention = !parameters.has_relative_position_bias && + use_flash_attention = nullptr == attention_bias && parameters.head_size == parameters.v_head_size && onnxruntime::flash::is_supported(device_prop, parameters.head_size, @@ -254,9 +223,9 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co #if USE_MEMORY_EFFICIENT_ATTENTION if (!use_flash_attention && nullptr == fused_runner && !disable_memory_efficient_attention_) { int sm = device_prop.major * 10 + device_prop.minor; - bool is_good_for_rpb = !parameters.has_relative_position_bias || parameters.sequence_length % (4 * sizeof(T)) == 0; + bool is_attn_bias_aligned = nullptr == attention_bias || 
parameters.sequence_length % (4 * sizeof(T)) == 0; use_memory_efficient_attention = - is_good_for_rpb && + is_attn_bias_aligned && (sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) && has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size); } @@ -304,9 +273,9 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co data.key = (key == nullptr) ? nullptr : reinterpret_cast(key->Data()); data.value = (value == nullptr) ? nullptr : reinterpret_cast(value->Data()); data.bias = (bias == nullptr) ? nullptr : reinterpret_cast(bias->Data()); - data.relative_position_bias = (nullptr == relative_position_bias) - ? nullptr - : reinterpret_cast(relative_position_bias->Data()); + data.attention_bias = (nullptr == attention_bias) + ? nullptr + : reinterpret_cast(attention_bias->Data()); data.workspace = reinterpret_cast(work_space.get()); data.token_offset = token_offset->Data(); data.cumulative_sequence_length = cumulative_sequence_length->Data(); diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index e5a4c54f48903..7bcb589c1d98b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -701,8 +701,8 @@ Status FusedAttentionCutlass( p.query = data.no_qkv_workspace ? data.query : data.workspace; p.key = data.no_qkv_workspace ? data.key : (data.workspace + elements_qk); p.value = data.no_qkv_workspace ? 
data.value : (data.workspace + elements_qk + elements_qk); - p.attn_bias = data.relative_position_bias; - p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; + p.attn_bias = data.attention_bias; + p.attn_bias_dims = parameters.attention_bias_dims; p.output = data.output; p.is_kv_bsnh = true; p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) @@ -788,14 +788,18 @@ Status UnfusedAttention( sequence_length); T* attention_score = scaled_qk + (bytes / element_size); + bool broadcast_attn_bias_dim_0 = parameters.attention_bias_dims.size() > 0 && parameters.attention_bias_dims[0] == 1; + bool broadcast_attn_bias_dim_1 = parameters.attention_bias_dims.size() > 1 && parameters.attention_bias_dims[1] == 1; // Apply softmax and store result R to attention_score: BxNxSxS ORT_RETURN_IF_ERROR(ComputeSoftmaxWithCumSeqLength( scaled_qk, - data.relative_position_bias, - parameters.broadcast_res_pos_bias, + data.attention_bias, + broadcast_attn_bias_dim_0, + broadcast_attn_bias_dim_1, data.cumulative_sequence_length, batch_size, sequence_length, + sequence_length, // total sequence length num_heads, attention_score, stream)); diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.h index eeca72f16e64e..9d0ff77e5fcaa 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.h @@ -17,7 +17,8 @@ struct PackedMultiHeadAttentionData { const T* key; const T* value; const T* bias; - const T* relative_position_bias; + const T* attention_bias; + const int32_t* token_offset; const int32_t* cumulative_sequence_length; diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc index 6d52ff7282799..5c39cf56dfd92 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc +++ 
b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc @@ -335,6 +335,29 @@ void CudaTensorConsoleDumper::Print(const char* name, const std::string& value, } } +void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, gsl::span& dims) const { + PrintTensorByDims(this, name, tensor, dims); +} +void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, gsl::span& dims) const { + PrintTensorByDims(this, name, tensor, dims); +} + +void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, gsl::span& dims) const { + PrintTensorByDims(this, name, tensor, dims); +} + +void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, gsl::span& dims) const { + PrintTensorByDims(this, name, tensor, dims); +} + +void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, gsl::span& dims) const { + PrintTensorByDims(this, name, tensor, dims); +} + +void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, gsl::span& dims) const { + PrintTensorByDims(this, name, tensor, dims); +} + #else CudaTensorConsoleDumper::CudaTensorConsoleDumper() { } @@ -410,6 +433,25 @@ void CudaTensorConsoleDumper::Print(const char*, int, bool) const { void CudaTensorConsoleDumper::Print(const char*, const std::string&, bool) const { } + +void CudaTensorConsoleDumper::Print(const char*, const int32_t*, gsl::span&) const { +} + +void CudaTensorConsoleDumper::Print(const char*, const int64_t*, gsl::span&) const { +} + +void CudaTensorConsoleDumper::Print(const char*, const float*, gsl::span&) const { +} + +void CudaTensorConsoleDumper::Print(const char*, const half*, gsl::span&) const { +} + +void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, gsl::span&) const { +} + +void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, gsl::span&) const { +} + #endif } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h 
b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h index 4f41161cd4a31..631421b1623be 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h @@ -21,26 +21,32 @@ class CudaTensorConsoleDumper : public onnxruntime::contrib::IConsoleDumper { void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const override; void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const override; void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2, int dim3) const override; + void Print(const char* name, const int32_t* tensor, gsl::span& dims) const override; void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const override; void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const override; void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2, int dim3) const override; + void Print(const char* name, const int64_t* tensor, gsl::span& dims) const override; void Print(const char* name, const float* tensor, int dim0, int dim1) const override; void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const override; void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2, int dim3) const override; - - void Print(const char* name, const half* tensor, int dim0, int dim1) const; - void Print(const char* name, const half* tensor, int dim0, int dim1, int dim2) const; - void Print(const char* name, const half* tensor, int dim0, int dim1, int dim2, int dim3) const; + void Print(const char* name, const float* tensor, gsl::span& dims) const override; void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override; void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const override; void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2, int 
dim3) const override; + void Print(const char* name, const MLFloat16* tensor, gsl::span& dims) const override; + + void Print(const char* name, const half* tensor, int dim0, int dim1) const; + void Print(const char* name, const half* tensor, int dim0, int dim1, int dim2) const; + void Print(const char* name, const half* tensor, int dim0, int dim1, int dim2, int dim3) const; + void Print(const char* name, const half* tensor, gsl::span& dims) const; void Print(const char* name, const BFloat16* tensor, int dim0, int dim1) const; void Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2) const; void Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const; + void Print(const char* name, const BFloat16* tensor, gsl::span& dims) const; void Print(const char* name, const Tensor& value) const override; void Print(const char* name, const OrtValue& value) const override; diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.cu b/onnxruntime/contrib_ops/rocm/bert/attention.cu index 96cc17734874c..473ab8dd3ce4d 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/attention.cu @@ -53,7 +53,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { const Tensor* bias = context->Input(2); const Tensor* mask_index = context->Input(3); const Tensor* past = context->Input(4); - const Tensor* relative_position_bias = context->Input(5); + const Tensor* attention_bias = context->Input(5); const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); auto& device_prop = GetDeviceProp(); @@ -63,7 +63,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { bias->Shape(), mask_index, past, - relative_position_bias, + attention_bias, &attn, device_prop.maxThreadsPerBlock, past_seq_len)); @@ -190,8 +190,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { params.v_buffer = v_buffer; params.out_buffer = 
reinterpret_cast(output->MutableDataRaw()); - if (relative_position_bias != nullptr) { - params.bias_buffer = reinterpret_cast(relative_position_bias->DataRaw()); + if (attention_bias != nullptr) { + params.bias_buffer = reinterpret_cast(attention_bias->DataRaw()); } if (mask_index != nullptr) { diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh index 54dda4bfa6d2c..e013f35e150c4 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh @@ -398,7 +398,8 @@ struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams { const T* v_buffer; T* out_buffer; - // optional, bias [B,N,S,T] + // optional, attention bias [B,N,S,T] + // TODO: support shape [B,1,S,T], [1, N, S, T], [1, 1, S, T] with broadcast. const T* bias_buffer{nullptr}; // optional, mask value diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu index 5997daaca6e8a..b07f9214e340e 100644 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu @@ -87,7 +87,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* bias{}; const Tensor* key_padding_mask{}; - const Tensor* relative_position_bias{}; + const Tensor* attention_bias{}; const Tensor* past_key{}; const Tensor* past_value{}; const Tensor* past_seq_len{}; @@ -95,12 +95,12 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { if (attn_type_ == kMultiHeadAttention) { bias = context->Input(3); key_padding_mask = context->Input(4); - relative_position_bias = context->Input(5); + attention_bias = context->Input(5); past_key = context->Input(6); past_value = context->Input(7); } else if 
(attn_type_ == kDecoderMaskedMultiHeadAttention) { key_padding_mask = context->Input(3); - relative_position_bias = context->Input(4); + attention_bias = context->Input(4); past_key = context->Input(5); past_value = context->Input(6); past_seq_len = context->Input(kPastSequenceLengthInputIndex); @@ -120,7 +120,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { ORT_RETURN_IF_ERROR( multihead_attention_helper::CheckInputs( query, key, value, bias, - key_padding_mask, relative_position_bias, + key_padding_mask, attention_bias, past_key, past_value, past_seq_len, &attn, num_heads_, mask_filter_value_, scale_, false, /*is_unidirectional_*/ @@ -263,8 +263,8 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { params.mask_index_dims = key_padding_mask->Shape().AsShapeVector(); } - if (relative_position_bias != nullptr) { - params.bias_buffer = reinterpret_cast(relative_position_bias->DataRaw()); + if (attention_bias != nullptr) { + params.bias_buffer = reinterpret_cast(attention_bias->DataRaw()); } params.workspace_buffer = reinterpret_cast(workspace.get()); diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 7272a949f7218..0745dcdf231e6 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1006,9 +1006,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "M", OpSchema::Optional) .Input(5, - "relative_position_bias", - "relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)" - " or (1, num_heads, sequence_length, total_sequence_length)", + "attention_bias", + "bias added to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)", "T", OpSchema::Optional) .Input(6, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index 14a7383e67897..788293464d3b3 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -2371,7 +2371,7 @@ constexpr DML_SCHEMA_FIELD DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA_FIELDS[18] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "StackedQueryKeyValueTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "MaskTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "RelativePositionBiasTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AttentionBiasTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "PastKeyTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "PastValueTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, @@ -2502,7 +2502,7 @@ constexpr DML_SCHEMA_FIELD DML_MULTIHEAD_ATTENTION1_OPERATOR_SCHEMA_FIELDS[20] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "StackedQueryKeyValueTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "MaskTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "RelativePositionBiasTensor", true }, + 
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AttentionBiasTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "PastKeyTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "PastValueTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "PastSequenceLengthsTensor", true }, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h index 23b5a491c7d96..04ad595b241b0 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h @@ -1471,7 +1471,7 @@ inline std::vector GetFields(const DML_MULTIHEAD_ATTENTION_OPERAT OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.StackedQueryKeyValueTensor))), OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.BiasTensor))), OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.MaskTensor))), - OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.RelativePositionBiasTensor))), + OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.AttentionBiasTensor))), OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.PastKeyTensor))), OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.PastValueTensor))), 
OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.OutputTensor))), @@ -1566,7 +1566,7 @@ inline std::vector GetFields(const DML_MULTIHEAD_ATTENTION1_OPERA OperatorField(&DML_MULTIHEAD_ATTENTION1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.StackedQueryKeyValueTensor))), OperatorField(&DML_MULTIHEAD_ATTENTION1_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.BiasTensor))), OperatorField(&DML_MULTIHEAD_ATTENTION1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.MaskTensor))), - OperatorField(&DML_MULTIHEAD_ATTENTION1_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.RelativePositionBiasTensor))), + OperatorField(&DML_MULTIHEAD_ATTENTION1_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.AttentionBiasTensor))), OperatorField(&DML_MULTIHEAD_ATTENTION1_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.PastKeyTensor))), OperatorField(&DML_MULTIHEAD_ATTENTION1_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.PastValueTensor))), OperatorField(&DML_MULTIHEAD_ATTENTION1_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.PastSequenceLengthsTensor))), diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp index 73c2d57e984af..5409d1c653d47 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp @@ -47,7 +47,7 @@ class DmlOperatorAttention : public DmlOperator mhaStackedQueryKeyValueIndex, mhaBiasIndex, mhaMaskIndex, - mhaRelativePositionBiasIndex, + mhaAttentionBiasIndex, mhaPastKeyIndex, mhaPastValueIndex, mhaInputCount, @@ -60,7 +60,7 @@ class DmlOperatorAttention : public DmlOperator biasIndex, maskIndex, pastIndex, - 
relativePositionBiasIndex, + attentionBiasIndex, pastSequenceLengthIndex, inputCount, }; @@ -78,12 +78,12 @@ class DmlOperatorAttention : public DmlOperator const uint32_t dmlWeightsIndex = weightsIndex; const uint32_t dmlBiasIndex = biasIndex; const uint32_t dmlMaskIndex = maskIndex; - const uint32_t dmlRelativePositionBiasIndex = relativePositionBiasIndex; + const uint32_t dmlAttentionBiasIndex = attentionBiasIndex; const bool hasBias = kernelCreationContext.IsInputValid(biasIndex); const bool hasMask = kernelCreationContext.IsInputValid(maskIndex); const bool hasUnpaddedBounds = hasMask && kernelCreationContext.GetInputTensorDimensionCount(maskIndex) == 1; - const bool hasRelativePositionBias = kernelCreationContext.IsInputValid(relativePositionBiasIndex); + const bool hasAttentionBias = kernelCreationContext.IsInputValid(attentionBiasIndex); DmlOperator::Initialize(kernelCreationContext, std::nullopt, std::nullopt, std::nullopt, std::nullopt, 1); @@ -188,13 +188,13 @@ class DmlOperatorAttention : public DmlOperator } } - if (hasRelativePositionBias) + if (hasAttentionBias) { - auto relativePositionBiasTensorShape = m_inputTensorDescs[dmlRelativePositionBiasIndex].GetSizes(); - ML_CHECK_VALID_ARGUMENT(relativePositionBiasTensorShape.size() == 4); - ML_CHECK_VALID_ARGUMENT(relativePositionBiasTensorShape[0] == inputTensorShape[0]); - ML_CHECK_VALID_ARGUMENT(relativePositionBiasTensorShape[1] == numHeads); - ML_CHECK_VALID_ARGUMENT(relativePositionBiasTensorShape[2] == inputTensorShape[1]); + auto attentionBiasTensorShape = m_inputTensorDescs[dmlAttentionBiasIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(attentionBiasTensorShape.size() == 4); + ML_CHECK_VALID_ARGUMENT(attentionBiasTensorShape[0] == inputTensorShape[0]); + ML_CHECK_VALID_ARGUMENT(attentionBiasTensorShape[1] == numHeads); + ML_CHECK_VALID_ARGUMENT(attentionBiasTensorShape[2] == inputTensorShape[1]); } TensorDesc firstGemmOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, 
desiredBiasTensorShape); @@ -346,7 +346,7 @@ class DmlOperatorAttention : public DmlOperator mhaOperatorDesc.MaskTensor = hasMask ? &inputDescs[dmlMaskIndex] : nullptr; } - mhaOperatorDesc.RelativePositionBiasTensor = hasRelativePositionBias ? &inputDescs[dmlRelativePositionBiasIndex] : nullptr; + mhaOperatorDesc.AttentionBiasTensor = hasAttentionBias ? &inputDescs[dmlAttentionBiasIndex] : nullptr; mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex]; mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute(AttrName::Scale, gsl::narrow_cast(1.0f / std::sqrt(headSize))); mhaOperatorDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute(AttrName::MaskFilterValue, -10'000.0f); @@ -452,13 +452,13 @@ class DmlOperatorAttention : public DmlOperator } } - if (hasRelativePositionBias) + if (hasAttentionBias) { - DML_INPUT_GRAPH_EDGE_DESC relativePositionBiasToMhaEdge = {}; - relativePositionBiasToMhaEdge.GraphInputIndex = dmlRelativePositionBiasIndex; - relativePositionBiasToMhaEdge.ToNodeIndex = mhaNodeIndex; - relativePositionBiasToMhaEdge.ToNodeInputIndex = mhaRelativePositionBiasIndex; - inputEdges.push_back(relativePositionBiasToMhaEdge); + DML_INPUT_GRAPH_EDGE_DESC attentionBiasToMhaEdge = {}; + attentionBiasToMhaEdge.GraphInputIndex = dmlAttentionBiasIndex; + attentionBiasToMhaEdge.ToNodeIndex = mhaNodeIndex; + attentionBiasToMhaEdge.ToNodeInputIndex = mhaAttentionBiasIndex; + inputEdges.push_back(attentionBiasToMhaEdge); } if (hasSlicedValue) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp index cde08864ca54e..96d2408f118a6 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp @@ -18,7 +18,7 @@ class 
DmlOperatorMultiHeadAttention : public DmlOperator valueIndex, biasIndex, maskIndex, - relativePositionBiasIndex, + attentionBiasIndex, pastKeyIndex, pastValueIndex, inputCount, @@ -34,7 +34,7 @@ class DmlOperatorMultiHeadAttention : public DmlOperator dmlStackedQueryKeyValueIndex, dmlBiasIndex, dmlMaskIndex, - dmlRelativePositionBiasIndex, + dmlAttentionBiasIndex, dmlPastKeyIndex, dmlPastValueIndex, dmlInputCount, @@ -55,7 +55,7 @@ class DmlOperatorMultiHeadAttention : public DmlOperator const bool hasValue = kernelCreationContext.IsInputValid(valueIndex) && !keyValueIsPast; const bool hasBias = kernelCreationContext.IsInputValid(biasIndex); const bool hasMask = kernelCreationContext.IsInputValid(maskIndex); - const bool hasRelativePositionBias = kernelCreationContext.IsInputValid(relativePositionBiasIndex); + const bool hasAttentionBias = kernelCreationContext.IsInputValid(attentionBiasIndex); const bool hasPastKey = keyValueIsPast || (kernelCreationContext.IsInputValid(pastKeyIndex) && kernelCreationContext.GetInputTensorShape(pastKeyIndex)[2] != 0); const bool hasPastValue = keyValueIsPast || (kernelCreationContext.IsInputValid(pastValueIndex) && kernelCreationContext.GetInputTensorShape(pastValueIndex)[2] != 0); const bool hasPresentKeyOutput = kernelCreationContext.IsOutputValid(outputPresentKeyIndex); @@ -73,7 +73,7 @@ class DmlOperatorMultiHeadAttention : public DmlOperator stackedQkv ? std::optional(queryIndex) : std::nullopt, biasIndex, hasMask ? std::optional(maskIndex) : std::nullopt, - relativePositionBiasIndex, + attentionBiasIndex, hasPastKey ? std::optional(keyValueIsPast ? keyIndex : pastKeyIndex) : std::nullopt, hasPastValue ? std::optional(keyValueIsPast ? 
valueIndex : pastValueIndex) : std::nullopt, }; @@ -243,15 +243,15 @@ class DmlOperatorMultiHeadAttention : public DmlOperator } } - if (hasRelativePositionBias) + if (hasAttentionBias) { - ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[dmlRelativePositionBiasIndex].GetDimensionCount() == 4); + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[dmlAttentionBiasIndex].GetDimensionCount() == 4); - auto relativePositionBiasSizes = m_inputTensorDescs[dmlRelativePositionBiasIndex].GetSizes(); - ML_CHECK_VALID_ARGUMENT(relativePositionBiasSizes[0] == batchSize); - ML_CHECK_VALID_ARGUMENT(relativePositionBiasSizes[1] == numHeads); - ML_CHECK_VALID_ARGUMENT(relativePositionBiasSizes[2] == sequenceLength); - ML_CHECK_VALID_ARGUMENT(relativePositionBiasSizes[3] == totalSequenceLength); + auto attentionBiasSizes = m_inputTensorDescs[dmlAttentionBiasIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(attentionBiasSizes[0] == batchSize); + ML_CHECK_VALID_ARGUMENT(attentionBiasSizes[1] == numHeads); + ML_CHECK_VALID_ARGUMENT(attentionBiasSizes[2] == sequenceLength); + ML_CHECK_VALID_ARGUMENT(attentionBiasSizes[3] == totalSequenceLength); } if (hasPastKey) @@ -283,7 +283,7 @@ class DmlOperatorMultiHeadAttention : public DmlOperator mhaDesc.StackedQueryKeyValueTensor = stackedQkv ? &inputDescs[dmlStackedQueryKeyValueIndex] : nullptr; mhaDesc.BiasTensor = hasBias ? &inputDescs[dmlBiasIndex] : nullptr; mhaDesc.MaskTensor = hasMask ? &inputDescs[dmlMaskIndex] : nullptr; - mhaDesc.RelativePositionBiasTensor = hasRelativePositionBias ? &inputDescs[dmlRelativePositionBiasIndex] : nullptr; + mhaDesc.AttentionBiasTensor = hasAttentionBias ? &inputDescs[dmlAttentionBiasIndex] : nullptr; mhaDesc.PastKeyTensor = hasPastKey ? &inputDescs[dmlPastKeyIndex] : nullptr; mhaDesc.PastValueTensor = hasPastValue ? 
&inputDescs[dmlPastValueIndex] : nullptr; mhaDesc.OutputTensor = &outputDescs[outputIndex]; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp index f9519b26bb4e3..f45cb6c90b352 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp @@ -89,7 +89,7 @@ class DmlOperatorQAttention : public DmlOperator mhaStackedQueryKeyValueIndex, mhaBiasIndex, mhaMaskIndex, - mhaRelativePositionBiasIndex, + mhaAttentionBiasIndex, mhaPastKeyIndex, mhaPastValueIndex, mhaInputCount, @@ -415,10 +415,10 @@ class DmlOperatorQAttention : public DmlOperator mhaOperatorDesc.MaskTensor = hasMask ? &inputDescs[maskIndex] : nullptr; } - mhaOperatorDesc.RelativePositionBiasTensor = nullptr; + mhaOperatorDesc.AttentionBiasTensor = nullptr; mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex]; mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute(AttrName::Scale, gsl::narrow_cast(1.0f / std::sqrt(headSize))); - // Set MaskFilterValue to lowest float for Causal Mask + // Set MaskFilterValue to lowest float for Causal Mask mhaOperatorDesc.MaskFilterValue = unidirectional ? 
std::numeric_limits::lowest() : kernelCreationContext.GetOptionalAttribute(AttrName::MaskFilterValue, -10'000.0f); mhaOperatorDesc.HeadCount = numHeads; diff --git a/onnxruntime/python/tools/transformers/constants.py b/onnxruntime/python/tools/transformers/constants.py index fc8f2cc2f58d3..0da22dc149968 100644 --- a/onnxruntime/python/tools/transformers/constants.py +++ b/onnxruntime/python/tools/transformers/constants.py @@ -21,7 +21,7 @@ class AttentionInputIDs: BIAS = 2 MASK_INDEX = 3 PAST = 4 - RELATIVE_POSITION_BIAS = 5 + ATTENTION_BIAS = 5 PAST_SEQUENCE_LENGTH = 6 @@ -36,7 +36,7 @@ class MultiHeadAttentionInputIDs: VALUE = 2 BIAS = 3 KEY_PADDING_MASK = 4 - RELATIVE_POSITION_BIAS = 5 + ATTENTION_BIAS = 5 PAST_KEY = 6 PAST_VALUE = 7 diff --git a/onnxruntime/python/tools/transformers/convert_to_packing_mode.py b/onnxruntime/python/tools/transformers/convert_to_packing_mode.py index 4da97f0de7bed..e854312cae826 100644 --- a/onnxruntime/python/tools/transformers/convert_to_packing_mode.py +++ b/onnxruntime/python/tools/transformers/convert_to_packing_mode.py @@ -184,9 +184,9 @@ def _are_attentions_supported(self) -> bool: def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None: for attention in self.attention_nodes: - relative_pos_bias = ( - attention.input[AttentionInputIDs.RELATIVE_POSITION_BIAS] - if len(attention.input) > AttentionInputIDs.RELATIVE_POSITION_BIAS + attention_bias = ( + attention.input[AttentionInputIDs.ATTENTION_BIAS] + if len(attention.input) > AttentionInputIDs.ATTENTION_BIAS else "" ) packed_attention = helper.make_node( @@ -197,7 +197,7 @@ def _replace_attention_with_packing_attention(self, token_offset: str, cumulativ attention.input[AttentionInputIDs.BIAS], token_offset, cumulative_sequence_length, - relative_pos_bias, + attention_bias, ], outputs=[attention.output[AttentionOutputIDs.OUTPUT]], name=self.model.create_node_name(Operators.PACKEDATTENTION), @@ -261,9 +261,9 @@ def 
_are_attentions_supported(self) -> bool: def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None: gated_relative_pos_bias_count = 0 for mha in self.attention_nodes: - relative_pos_bias = ( - mha.input[MultiHeadAttentionInputIDs.RELATIVE_POSITION_BIAS] - if len(mha.input) > MultiHeadAttentionInputIDs.RELATIVE_POSITION_BIAS + attention_bias = ( + mha.input[MultiHeadAttentionInputIDs.ATTENTION_BIAS] + if len(mha.input) > MultiHeadAttentionInputIDs.ATTENTION_BIAS else "" ) packed_mha = helper.make_node( @@ -275,7 +275,7 @@ def _replace_attention_with_packing_attention(self, token_offset: str, cumulativ mha.input[MultiHeadAttentionInputIDs.BIAS], token_offset, cumulative_sequence_length, - relative_pos_bias, + attention_bias, ], outputs=[mha.output[MultiHeadAttentionOutputIDs.OUTPUT]], name=self.model.create_node_name(Operators.PACKED_MULTI_HEAD_ATTENTION), @@ -293,8 +293,8 @@ def _replace_attention_with_packing_attention(self, token_offset: str, cumulativ self.node_name_to_graph_name[packed_mha.name] = self.this_graph_name # Append token_offset input to GatedRelativePositionBias - if relative_pos_bias: - rel_pos_bias_node = self.model.get_parent(mha, MultiHeadAttentionInputIDs.RELATIVE_POSITION_BIAS) + if attention_bias: + rel_pos_bias_node = self.model.get_parent(mha, MultiHeadAttentionInputIDs.ATTENTION_BIAS) if ( rel_pos_bias_node and rel_pos_bias_node.op_type == "GatedRelativePositionBias" diff --git a/onnxruntime/test/contrib_ops/attention_op_test_helper.cc b/onnxruntime/test/contrib_ops/attention_op_test_helper.cc index 79e1a8f0fdc19..1ea67314f62d6 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test_helper.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test_helper.cc @@ -297,7 +297,7 @@ void GetCrossAttentionDataWithPast(AttentionTestData& data) { data.fp16_output_data = data.fp32_output_data; } -void GetSelfAttentionData_WithPast_WithRelPosBias_ForT5(AttentionTestData& data) { +void 
GetSelfAttentionData_WithPast_WithAttnBias_ForT5(AttentionTestData& data) { data.hidden_size = 8; data.v_hidden_size = 8; data.num_heads = 2; @@ -313,21 +313,21 @@ void GetSelfAttentionData_WithPast_WithRelPosBias_ForT5(AttentionTestData& data) AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention, }; - LoadTensor("SelfAttentionData_WithPast_WithRelPosBias_ForT5.query_data", data.query_data); - LoadTensor("SelfAttentionData_WithPast_WithRelPosBias_ForT5.key_data", data.key_data); - LoadTensor("SelfAttentionData_WithPast_WithRelPosBias_ForT5.value_data", data.value_data); - LoadTensor("SelfAttentionData_WithPast_WithRelPosBias_ForT5.rel_pos_bias_data", data.rel_pos_bias_data); - data.broadcast_rel_pos_bias = false; - LoadTensor("SelfAttentionData_WithPast_WithRelPosBias_ForT5.past_key_data", data.past_key_data); - LoadTensor("SelfAttentionData_WithPast_WithRelPosBias_ForT5.past_value_data", data.past_value_data); - LoadTensor("SelfAttentionData_WithPast_WithRelPosBias_ForT5.fp32_output_data", data.fp32_output_data); + LoadTensor("SelfAttentionData_WithPast_WithAttnBias_ForT5.query_data", data.query_data); + LoadTensor("SelfAttentionData_WithPast_WithAttnBias_ForT5.key_data", data.key_data); + LoadTensor("SelfAttentionData_WithPast_WithAttnBias_ForT5.value_data", data.value_data); + LoadTensor("SelfAttentionData_WithPast_WithAttnBias_ForT5.attention_bias_data", data.attention_bias_data); + data.broadcast_attention_bias = false; + LoadTensor("SelfAttentionData_WithPast_WithAttnBias_ForT5.past_key_data", data.past_key_data); + LoadTensor("SelfAttentionData_WithPast_WithAttnBias_ForT5.past_value_data", data.past_value_data); + LoadTensor("SelfAttentionData_WithPast_WithAttnBias_ForT5.fp32_output_data", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; - LoadTensor("SelfAttentionData_WithPast_WithRelPosBias_ForT5.present_key_data", data.present_key_data); - 
LoadTensor("SelfAttentionData_WithPast_WithRelPosBias_ForT5.present_value_data", data.present_value_data); + LoadTensor("SelfAttentionData_WithPast_WithAttnBias_ForT5.present_key_data", data.present_key_data); + LoadTensor("SelfAttentionData_WithPast_WithAttnBias_ForT5.present_value_data", data.present_value_data); data.is_static_kv = false; } -void GetAttentionDataCutlassRelPosBias(AttentionTestData& data) { +void GetAttentionDataCutlassAttnBias(AttentionTestData& data) { data.hidden_size = 8; data.v_hidden_size = 8; data.num_heads = 2; @@ -343,13 +343,13 @@ void GetAttentionDataCutlassRelPosBias(AttentionTestData& data) { AttentionKernelType::AttentionKernel_TrtFusedCrossAttention, AttentionKernelType::AttentionKernel_TrtFusedAttention}; - LoadTensor("AttentionDataCutlassRelPosBias.query_data", data.query_data); - LoadTensor("AttentionDataCutlassRelPosBias.key_data", data.key_data); - LoadTensor("AttentionDataCutlassRelPosBias.value_data", data.value_data); - LoadTensor("AttentionDataCutlassRelPosBias.bias_data", data.bias_data); - LoadTensor("AttentionDataCutlassRelPosBias.rel_pos_bias_data", data.rel_pos_bias_data); - data.broadcast_rel_pos_bias = false; - LoadTensor("AttentionDataCutlassRelPosBias.fp16_output_data", data.fp16_output_data); + LoadTensor("AttentionDataCutlassAttnBias.query_data", data.query_data); + LoadTensor("AttentionDataCutlassAttnBias.key_data", data.key_data); + LoadTensor("AttentionDataCutlassAttnBias.value_data", data.value_data); + LoadTensor("AttentionDataCutlassAttnBias.bias_data", data.bias_data); + LoadTensor("AttentionDataCutlassAttnBias.attention_bias_data", data.attention_bias_data); + data.broadcast_attention_bias = false; + LoadTensor("AttentionDataCutlassAttnBias.fp16_output_data", data.fp16_output_data); data.fp32_output_data = {}; data.is_static_kv = false; } @@ -417,7 +417,7 @@ void GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(AttentionTestDat data.is_static_kv = true; } -void 
GetSelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias(AttentionTestData& data) { +void GetSelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias(AttentionTestData& data) { data.hidden_size = 8; data.v_hidden_size = 8; data.num_heads = 2; @@ -433,19 +433,19 @@ void GetSelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias(AttentionTestDa AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention, }; - LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.query_data", data.query_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.key_data", data.key_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.value_data", data.value_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.bias_data", data.bias_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.past_key_data", data.past_key_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.past_value_data", data.past_value_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.fp32_output_data", data.fp32_output_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.present_key_data", data.present_key_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.present_value_data", data.present_value_data); + LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.query_data", data.query_data); + LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.key_data", data.key_data); + LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.value_data", data.value_data); + LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.bias_data", data.bias_data); + LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.past_key_data", data.past_key_data); + LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.past_value_data", data.past_value_data); + 
LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.fp32_output_data", data.fp32_output_data); + LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.present_key_data", data.present_key_data); + LoadTensor("SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.present_value_data", data.present_value_data); data.is_static_kv = false; } -void GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias(AttentionTestData& data) { +void GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias(AttentionTestData& data) { data.hidden_size = 16; data.v_hidden_size = 16; data.num_heads = 2; @@ -461,37 +461,37 @@ void GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias(Atten AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention, }; - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.query_data", data.query_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.key_data", data.key_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.value_data", data.value_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.bias_data", data.bias_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.past_key_data", + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.query_data", data.query_data); + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.key_data", data.key_data); + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.value_data", data.value_data); + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.bias_data", data.bias_data); + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.past_key_data", data.past_key_data); - 
LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.past_value_data", + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.past_value_data", data.past_value_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.fp32_output_data", + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.fp32_output_data", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.present_key_data", + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.present_key_data", data.present_key_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.present_value_data", + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.present_value_data", data.present_value_data); data.is_static_kv = false; } -void GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias(AttentionTestData& data) { - GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias(data); +void GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias(AttentionTestData& data) { + GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias(data); data.bias_data.clear(); - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias.past_key_data", + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias.past_key_data", data.past_key_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias.past_value_data", + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias.past_value_data", data.past_value_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias.fp32_output_data", + 
LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias.fp32_output_data", data.fp32_output_data); data.fp16_output_data = data.fp32_output_data; - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias.present_key_data", + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias.present_key_data", data.present_key_data); - LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias.present_value_data", + LoadTensor("SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias.present_value_data", data.present_value_data); data.is_static_kv = false; } @@ -535,7 +535,7 @@ void GetAttentionDataWithNeoXRotaryEmbedding(std::vector& input, LoadTensor("AttentionDataWithNeoXRotaryEmbedding.output", output); } -void GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias(PackedAttentionTestData& data) { +void GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias(PackedAttentionTestData& data) { data.hidden_size = 32; data.v_hidden_size = 32; data.num_heads = 1; @@ -550,19 +550,19 @@ void GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias(PackedAttent data.skip_kernel_types = { AttentionKernelType::AttentionKernel_TrtFusedCrossAttention}; - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.query_data", data.query_data); - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.key_data", data.key_data); - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.value_data", data.value_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias.query_data", data.query_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias.key_data", data.key_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias.value_data", data.value_data); data.bias_data = {}; - 
LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.qkv_data", data.qkv_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias.qkv_data", data.qkv_data); // Do not test fp32 data.fp32_output_data = {}; - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.fp16_output_data", data.fp16_output_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias.fp16_output_data", data.fp16_output_data); } -void GetPackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias(PackedAttentionTestData& data) { +void GetPackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias(PackedAttentionTestData& data) { data.hidden_size = 16; data.v_hidden_size = 16; data.num_heads = 2; @@ -576,23 +576,23 @@ void GetPackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias(PackedAttention data.skip_kernel_types = { AttentionKernelType::AttentionKernel_TrtFusedCrossAttention}; - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.query_data", data.query_data); - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.key_data", data.key_data); - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.value_data", data.value_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias.query_data", data.query_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias.key_data", data.key_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias.value_data", data.value_data); data.bias_data = {}; - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.qkv_data", data.qkv_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias.qkv_data", data.qkv_data); // shape: batch_size, num_heads, sequence_length, sequence_length - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.rel_pos_bias_data", data.rel_pos_bias_data); - data.broadcast_rel_pos_bias = false; + 
LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias.attention_bias_data", data.attention_bias_data); + data.broadcast_attention_bias = false; // Do not test fp32 data.fp32_output_data = {}; - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.fp16_output_data", data.fp16_output_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias.fp16_output_data", data.fp16_output_data); } -void GetPackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias(PackedAttentionTestData& data) { +void GetPackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias(PackedAttentionTestData& data) { data.hidden_size = 16; data.v_hidden_size = 16; data.num_heads = 2; @@ -606,21 +606,21 @@ void GetPackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias(Packed data.skip_kernel_types = { AttentionKernelType::AttentionKernel_TrtFusedCrossAttention}; - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.query_data", data.query_data); - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.key_data", data.key_data); - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.value_data", data.value_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias.query_data", data.query_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias.key_data", data.key_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias.value_data", data.value_data); data.bias_data = {}; - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.qkv_data", data.qkv_data); + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias.qkv_data", data.qkv_data); // shape: 1, num_heads, sequence_length, sequence_length - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.rel_pos_bias_data", - data.rel_pos_bias_data); - 
data.broadcast_rel_pos_bias = true; + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias.attention_bias_data", + data.attention_bias_data); + data.broadcast_attention_bias = true; // Do not test fp32 data.fp32_output_data = {}; - LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.fp16_output_data", + LoadTensor("PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias.fp16_output_data", data.fp16_output_data); } diff --git a/onnxruntime/test/contrib_ops/attention_op_test_helper.h b/onnxruntime/test/contrib_ops/attention_op_test_helper.h index ee93cdca0cd82..b0dbe6e7b4ac7 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test_helper.h +++ b/onnxruntime/test/contrib_ops/attention_op_test_helper.h @@ -27,8 +27,8 @@ struct BaseAttentionTestData { std::vector qkv_data; std::vector bias_data; - std::vector rel_pos_bias_data; - bool broadcast_rel_pos_bias; + std::vector attention_bias_data; + bool broadcast_attention_bias; std::vector past_key_data; std::vector past_value_data; @@ -76,29 +76,29 @@ void GetCrossAttentionData_HeadSize8(AttentionTestData& data); void GetCrossAttentionData_HeadSize8_NoBias(AttentionTestData& data); void GetCrossAttentionDataWithPast(AttentionTestData& data); -void GetSelfAttentionData_WithPast_WithRelPosBias_ForT5(AttentionTestData& data); +void GetSelfAttentionData_WithPast_WithAttnBias_ForT5(AttentionTestData& data); void GetCrossAttentionData_DiffSequenceLengths(AttentionTestData& data); void GetCrossAttentionData_DiffSequenceLengths_HeadSize8(AttentionTestData& data); void GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(AttentionTestData& data); -void GetSelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias(AttentionTestData& data); -void GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias(AttentionTestData& data); -void GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias(AttentionTestData& data); +void 
GetSelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias(AttentionTestData& data); +void GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias(AttentionTestData& data); +void GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias(AttentionTestData& data); void GetCrossAttentionData_WithPastPassedInDirectly_NoMask(AttentionTestData& data); void GetCausal_EmptyPastState(std::vector& input, std::vector& output, std::vector& present); -void GetAttentionDataCutlassRelPosBias(AttentionTestData& data); +void GetAttentionDataCutlassAttnBias(AttentionTestData& data); void GetAttentionDataWithNeoXRotaryEmbedding(std::vector& input, std::vector& weights, std::vector& bias, std::vector& output); -void GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias(PackedAttentionTestData& data); +void GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias(PackedAttentionTestData& data); -void GetPackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias(PackedAttentionTestData& data); +void GetPackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias(PackedAttentionTestData& data); -void GetPackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias(PackedAttentionTestData& data); +void GetPackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias(PackedAttentionTestData& data); bool SkipAttentionKernel(AttentionTestData& data, AttentionKernelType kernel_type); } // namespace test diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index f0255d7ece84e..65727828f51fb 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc @@ -31,7 +31,7 @@ static void RunMultiHeadAttentionTest( const std::vector& kv_data, // packed_kv: [batch_size, kv_sequence_length, num_heads, 2, head_size] const std::vector& qkv_data, // packed_qkv: [batch_size, sequence_length, num_heads, 3, 
head_size] const std::vector& bias_data, // bias: [hidden_size + hidden_size + v_hidden_size] or empty - const std::vector& rel_pos_bias_data, // relative_position_bias: [1, num_heads, sequence_length, total_sequence_length] + const std::vector& attention_bias_data, // attention_bias: [1, num_heads, sequence_length, total_sequence_length] const std::vector& past_key_data, // past_key: [batch_size, num_heads, kv_sequence_length, head_size] const std::vector& past_value_data, // past_value: [batch_size, num_heads, kv_sequence_length, head_size] const std::vector& present_key_data, // present_key: [batch_size, num_heads, total_sequence_length, head_size] @@ -80,7 +80,7 @@ static void RunMultiHeadAttentionTest( std::vector value_dims = {batch_size, is_static_kv ? kv_sequence_length : sequence_length, v_hidden_size}; std::vector bias_dims = {hidden_size + hidden_size + v_hidden_size}; // TODO(wy): Introduce past sequence length to avoid using kv_sequence_length. - std::vector rel_pos_bias_dims = + std::vector attention_bias_dims = {1, num_heads, sequence_length, past_key_data.size() ? 
sequence_length + kv_sequence_length : sequence_length}; std::vector past_key_dims = {batch_size, num_heads, kv_sequence_length, hidden_size / num_heads}; std::vector past_value_dims = past_key_dims; @@ -144,8 +144,8 @@ static void RunMultiHeadAttentionTest( tester.AddOptionalInputEdge(); } - if (rel_pos_bias_data.size()) { - tester.AddInput("relative_position_bias", rel_pos_bias_dims, ToFloat16(rel_pos_bias_data)); + if (attention_bias_data.size()) { + tester.AddInput("relative_position_bias", attention_bias_dims, ToFloat16(attention_bias_data)); } else { tester.AddOptionalInputEdge(); } @@ -208,8 +208,8 @@ static void RunMultiHeadAttentionTest( tester.AddOptionalInputEdge(); } - if (rel_pos_bias_data.size()) { - tester.AddInput("relative_position_bias", rel_pos_bias_dims, rel_pos_bias_data); + if (attention_bias_data.size()) { + tester.AddInput("relative_position_bias", attention_bias_dims, attention_bias_data); } else { tester.AddOptionalInputEdge(); } @@ -276,7 +276,7 @@ static void RunMultiHeadAttentionKernel( const std::vector& kv_data, // packed_kv: [batch_size, kv_sequence_length, num_heads, 2, head_size] const std::vector& qkv_data, // packed_qkv: [batch_size, sequence_length, num_heads, 3, head_size] const std::vector& bias_data, // bias: [hidden_size + hidden_size + v_hidden_size] - const std::vector& rel_pos_bias_data, // relative_position_bias: [1, num_heads, sequence_length, total_sequence_length] + const std::vector& attention_bias_data, // attention_bias: [1, num_heads, sequence_length, total_sequence_length] const std::vector& past_key_data, // past_key: [batch_size, num_heads, kv_sequence_length, head_size] const std::vector& past_value_data, // past_value: [batch_size, num_heads, kv_sequence_length, head_size] const std::vector& present_key_data, // present_key: [batch_size, num_heads, total_sequence_length, head_size] @@ -306,7 +306,7 @@ static void RunMultiHeadAttentionKernel( 
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}, {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}}; RunMultiHeadAttentionTest( - query_data, key_data, value_data, kv_data, qkv_data, bias_data, rel_pos_bias_data, + query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); @@ -322,7 +322,7 @@ static void RunMultiHeadAttentionKernel( {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}}; RunMultiHeadAttentionTest( - query_data, key_data, value_data, kv_data, qkv_data, bias_data, rel_pos_bias_data, + query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); @@ -338,7 +338,7 @@ static void RunMultiHeadAttentionKernel( {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}, {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}}; RunMultiHeadAttentionTest( - query_data, key_data, value_data, kv_data, qkv_data, bias_data, rel_pos_bias_data, + query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, 
disable_rocm, disable_dml); @@ -355,7 +355,7 @@ static void RunMultiHeadAttentionKernel( {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}}; RunMultiHeadAttentionTest( - query_data, key_data, value_data, kv_data, qkv_data, bias_data, rel_pos_bias_data, + query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); @@ -372,7 +372,7 @@ static void RunMultiHeadAttentionKernel( {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}}; RunMultiHeadAttentionTest( - query_data, key_data, value_data, kv_data, qkv_data, bias_data, rel_pos_bias_data, + query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); @@ -387,7 +387,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu if (!SkipAttentionKernel(data, kernel_type)) { RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, - data.rel_pos_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, + data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, 
data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); @@ -400,7 +400,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu if (!SkipAttentionKernel(data, kernel_type)) { RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, - data.rel_pos_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, + data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); @@ -411,7 +411,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu kernel_type = AttentionKernelType::AttentionKernel_Default; RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, - data.rel_pos_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, + data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); @@ -423,7 +423,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu if (!SkipAttentionKernel(data, kernel_type)) { RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, - data.rel_pos_bias_data, data.past_key_data, data.past_value_data, 
data.present_key_data, + data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); @@ -433,7 +433,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu if (!SkipAttentionKernel(data, kernel_type)) { RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, - data.rel_pos_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, + data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); @@ -444,7 +444,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu if (!SkipAttentionKernel(data, kernel_type)) { RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, - data.rel_pos_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, + data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); @@ -454,7 +454,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu kernel_type = 
AttentionKernelType::AttentionKernel_Default; RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, - data.rel_pos_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, + data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); @@ -548,17 +548,17 @@ TEST(MultiHeadAttentionTest, CrossAttentionWithPast) { } #endif -TEST(MultiHeadAttentionTest, SelfAttention_WithPast_WithRelPosBias_ForT5) { +TEST(MultiHeadAttentionTest, SelfAttention_WithPast_WithAttnBias_ForT5) { ROCM_GTEST_SKIP("ROCm MHA only support head_size >= 8"); AttentionTestData data; - GetSelfAttentionData_WithPast_WithRelPosBias_ForT5(data); + GetSelfAttentionData_WithPast_WithAttnBias_ForT5(data); RunMultiHeadAttentionTests(data, true); } -TEST(MultiHeadAttentionTest, AttentionCutlassRelPosBias) { +TEST(MultiHeadAttentionTest, AttentionCutlassAttnBias) { // ROCM_GTEST_SKIP("ROCm does not support cutlass"); AttentionTestData data; - GetAttentionDataCutlassRelPosBias(data); + GetAttentionDataCutlassAttnBias(data); RunMultiHeadAttentionTests(data); } @@ -575,16 +575,16 @@ TEST(MultiHeadAttentionTest, CrossAttention_DiffSequenceLengths) { RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); } -TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoRelPosBias) { +TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoAttnBias) { // Whisper decoder self attention with past_kv and present_kv AttentionTestData data; - GetSelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias(data); + GetSelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias(data); 
RunMultiHeadAttentionTests(data); - GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias(data); + GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias(data); RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); - GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias(data); + GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias(data); RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); } diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test_data_gen.py b/onnxruntime/test/contrib_ops/multihead_attention_op_test_data_gen.py index e0cfc9d0f8e25..bdb0ffc6c50db 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test_data_gen.py +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test_data_gen.py @@ -502,7 +502,7 @@ def run_cross_diff_seqlen_headsize_8(): ) -def run_self_past_present_headsize_8_nomask_norelposbias(): +def run_self_past_present_headsize_8_nomask_no_attn_bias(): hidden_dim = 16 q_head_size = 8 v_head_size = 8 @@ -554,8 +554,8 @@ def create_test_data(): print("SelfAttention_Batch2_HeadSize32_PackedQKV") run_self_batch2_headsize_32_packed_qkv() - print("SelfAttention_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias") - run_self_past_present_headsize_8_nomask_norelposbias() + print("SelfAttention_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias") + run_self_past_present_headsize_8_nomask_no_attn_bias() print("CrossAttention_DiffSequenceLengths_HeadSize8") run_cross_diff_seqlen_headsize_8() diff --git a/onnxruntime/test/contrib_ops/packed_multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/packed_multihead_attention_op_test.cc index 5f811c8cf35f6..17862c0aca6fa 100644 --- a/onnxruntime/test/contrib_ops/packed_multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/packed_multihead_attention_op_test.cc @@ -32,8 +32,8 @@ namespace test { token_count, \ use_float16, \ use_scale, \ 
- relative_position_bias_data, \ - broadcast_relative_position_bias); + attention_bias_data, \ + broadcast_attention_bias); static void RunPackedMultiHeadAttentionTest( const std::vector& query_data, // query: [token_count, num_heads, 3, head_size] @@ -52,8 +52,8 @@ static void RunPackedMultiHeadAttentionTest( int token_count, bool use_float16, bool use_scale, - const std::vector& relative_position_bias_data, - bool broadcast_relative_position_bias) { + const std::vector& attention_bias_data, + bool broadcast_attention_bias) { int min_cuda_architecture = use_float16 ? 530 : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); @@ -73,9 +73,9 @@ static void RunPackedMultiHeadAttentionTest( std::vector bias_dims = {hidden_size + hidden_size + v_hidden_size}; std::vector token_offset_dims = {batch_size, sequence_length}; std::vector cum_seq_len_dims = {batch_size + 1}; - std::vector relative_position_bias_data_dims = {batch_size, number_of_heads, sequence_length, sequence_length}; - std::vector broadcast_relative_position_bias_data_dims = {1, number_of_heads, sequence_length, sequence_length}; - auto& rel_pos_bias_dims = (broadcast_relative_position_bias ? broadcast_relative_position_bias_data_dims : relative_position_bias_data_dims); + std::vector attention_bias_data_dims = {batch_size, number_of_heads, sequence_length, sequence_length}; + std::vector broadcast_attention_bias_data_dims = {1, number_of_heads, sequence_length, sequence_length}; + auto& rel_pos_bias_dims = (broadcast_attention_bias ? 
broadcast_attention_bias_data_dims : attention_bias_data_dims); std::vector output_dims = {token_count, v_hidden_size}; @@ -100,10 +100,10 @@ static void RunPackedMultiHeadAttentionTest( tester.AddInput("token_offset", token_offset_dims, token_offset); tester.AddInput("cumulative_sequence_length", cum_seq_len_dims, cumulative_sequence_length); - if (relative_position_bias_data.size() > 0) { - tester.AddInput("relative_position_bias", + if (attention_bias_data.size() > 0) { + tester.AddInput("attention_bias", rel_pos_bias_dims, - ToFloat16(relative_position_bias_data)); + ToFloat16(attention_bias_data)); } tester.AddOutput("output", output_dims, ToFloat16(output_data)); @@ -127,8 +127,8 @@ static void RunPackedMultiHeadAttentionTest( tester.AddInput("token_offset", token_offset_dims, token_offset); tester.AddInput("cumulative_sequence_length", cum_seq_len_dims, cumulative_sequence_length); - if (relative_position_bias_data.size() > 0) { - tester.AddInput("relative_position_bias", rel_pos_bias_dims, relative_position_bias_data); + if (attention_bias_data.size() > 0) { + tester.AddInput("attention_bias", rel_pos_bias_dims, attention_bias_data); } tester.AddOutput("output", output_dims, output_data); @@ -157,8 +157,8 @@ static void RunPackedMultiHeadAttentionTest( int number_of_heads, int token_count, AttentionKernelType kernel_type, - const std::vector& relative_position_bias_data = {}, - bool broadcast_relative_position_bias = false) { + const std::vector& attention_bias_data = {}, + bool broadcast_attention_bias = false) { if (kernel_type == AttentionKernelType::AttentionKernel_TrtFusedAttention) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ @@ -310,9 +310,9 @@ TEST(PackedMultiHeadAttentionTest, Q_K_V_NoPadding_NoBias_trt) { AttentionKernelType::AttentionKernel_TrtFusedAttention); } -TEST(PackedMultiHeadAttentionTest, Q_K_V_NoPadding_Bias_RelPosBias_cutlass) { +TEST(PackedMultiHeadAttentionTest, Q_K_V_NoPadding_Bias_AttnBias_cutlass) { AttentionTestData 
data; - GetAttentionDataCutlassRelPosBias(data); + GetAttentionDataCutlassAttnBias(data); std::vector token_offset{0, 1, 2, 3, 4, 5, 6, 7}; std::vector cum_seq_len{0, 8}; @@ -331,13 +331,13 @@ TEST(PackedMultiHeadAttentionTest, Q_K_V_NoPadding_Bias_RelPosBias_cutlass) { data.num_heads, data.batch_size * data.sequence_length, AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention, - data.rel_pos_bias_data, - data.broadcast_rel_pos_bias); + data.attention_bias_data, + data.broadcast_attention_bias); } -TEST(PackedMultiHeadAttentionTest, Q_K_V_NoPadding_Bias_RelPosBias_unfused) { +TEST(PackedMultiHeadAttentionTest, Q_K_V_NoPadding_Bias_AttnBias_unfused) { AttentionTestData data; - GetAttentionDataCutlassRelPosBias(data); + GetAttentionDataCutlassAttnBias(data); std::vector token_offset{0, 1, 2, 3, 4, 5, 6, 7}; std::vector cum_seq_len{0, 8}; @@ -356,13 +356,13 @@ TEST(PackedMultiHeadAttentionTest, Q_K_V_NoPadding_Bias_RelPosBias_unfused) { data.num_heads, data.batch_size * data.sequence_length, AttentionKernelType::AttentionKernel_Unfused, - data.rel_pos_bias_data, - data.broadcast_rel_pos_bias); + data.attention_bias_data, + data.broadcast_attention_bias); } TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_trt) { PackedAttentionTestData data; - GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias(data); + GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias(data); std::vector empty_data = {}; RunPackedMultiHeadAttentionTest( @@ -384,7 +384,7 @@ TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_trt) { TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_cutlass) { PackedAttentionTestData data; - GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias(data); + GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias(data); std::vector empty_data = {}; RunPackedMultiHeadAttentionTest( @@ -408,7 +408,7 @@ TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_cutlass) { 
TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_FlashAttention) { if (HasCudaEnvironment(800)) { PackedAttentionTestData data; - GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias(data); + GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias(data); std::vector empty_data = {}; RunPackedMultiHeadAttentionTest( @@ -432,7 +432,7 @@ TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_FlashAttention) { TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_unfused) { PackedAttentionTestData data; - GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias(data); + GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias(data); std::vector empty_data = {}; RunPackedMultiHeadAttentionTest( @@ -452,9 +452,9 @@ TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_unfused) { AttentionKernelType::AttentionKernel_Unfused); } -TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_RelPosBias) { +TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_AttnBias) { PackedAttentionTestData data; - GetPackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias(data); + GetPackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias(data); std::vector empty_data = {}; RunPackedMultiHeadAttentionTest( @@ -472,13 +472,13 @@ TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_RelPosBias) { data.num_heads, data.token_count, AttentionKernelType::AttentionKernel_Default, - data.rel_pos_bias_data, - data.broadcast_rel_pos_bias); + data.attention_bias_data, + data.broadcast_attention_bias); } -TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_BroadcastRelPosBias_cutlass) { +TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_BroadcastAttnBias_cutlass) { PackedAttentionTestData data; - GetPackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias(data); + GetPackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias(data); std::vector empty_data = {}; RunPackedMultiHeadAttentionTest( @@ -496,13 
+496,13 @@ TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_BroadcastRelPosBias_ data.num_heads, data.token_count, AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention, - data.rel_pos_bias_data, - data.broadcast_rel_pos_bias); + data.attention_bias_data, + data.broadcast_attention_bias); } -TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_BroadcastRelPosBias_unfused) { +TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_BroadcastAttnBias_unfused) { PackedAttentionTestData data; - GetPackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias(data); + GetPackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias(data); std::vector empty_data = {}; RunPackedMultiHeadAttentionTest( @@ -520,8 +520,8 @@ TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_BroadcastRelPosBias_ data.num_heads, data.token_count, AttentionKernelType::AttentionKernel_Unfused, - data.rel_pos_bias_data, - data.broadcast_rel_pos_bias); + data.attention_bias_data, + data.broadcast_attention_bias); } } // namespace test diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index 0c52ee690af82..791fff2a8969d 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -89,7 +89,7 @@ def __init__( past_sequence_length: int = 0, kv_sequence_length=None, max_cache_sequence_length=None, - softmax_scale: float = 0.0, + scale: float = 0.0, provider="CPUExecutionProvider", device: Optional[torch.device] = None, enable_cuda_graph: bool = False, @@ -99,7 +99,10 @@ def __init__( share_past_present_buffer: bool = False, input_format: int = InputFormats.Q_K_V_BSNH_BSNH_BSNH, verbose: bool = False, - has_bias: bool = False, + has_bias: bool = False, # bias for input projection + has_attn_bias: bool = False, # bias added before softmax. For example,relative position bias. 
+ broadcast_attn_bias_dim_0: bool = False, # broadcast attention bias dimension 0 + broadcast_attn_bias_dim_1: bool = False, # broadcast attention bias dimension 1 mask_format: int = AttentionMaskFormat.Mask_None, ): self.operator = "MultiHeadAttention" @@ -111,7 +114,7 @@ def __init__( self.num_heads = num_heads self.head_size = head_size self.causal = causal - self.softmax_scale = softmax_scale or (1.0 / (head_size**0.5)) + self.scale = scale or (1.0 / (head_size**0.5)) # Support the case that there is no past but need present output (for prompt case). self.has_past_input = has_past_input @@ -151,6 +154,22 @@ def __init__( self.is_packed_kv = input_format == InputFormats.Q_KV_BSNH_BSN2H self.verbose = verbose self.has_bias = has_bias + self.has_attn_bias = has_attn_bias + self.broadcast_attn_bias_dim_0 = broadcast_attn_bias_dim_0 + self.broadcast_attn_bias_dim_1 = broadcast_attn_bias_dim_1 + + assert mask_format in [ + AttentionMaskFormat.Mask_None, + AttentionMaskFormat.Mask_1D_Key_SeqLen, + AttentionMaskFormat.Mask_2D_Key_PaddingMask, + ] + self.mask_format = mask_format + + # mask_index_q and mask_index_kv will be updated in random_inputs() if mask_format is not Mask_None. 
+ self.mask_index_kv = torch.ones(self.batch_size, dtype=torch.int32, device=self.device) * self.sequence_length + self.mask_index_q = ( + torch.ones(self.batch_size, dtype=torch.int32, device=self.device) * self.total_sequence_length + ) assert mask_format in [ AttentionMaskFormat.Mask_None, @@ -171,11 +190,14 @@ def __repr__(self): f"num_heads={self.num_heads}, head_size={self.head_size}, " f"kv_sequence_length={self.kv_sequence_length}, past_sequence_length={self.past_sequence_length}, " f"max_cache_sequence_length={self.max_cache_sequence_length}," - f"causal={self.causal}), softmax_scale={self.softmax_scale}, use_kv_cache={self.use_kv_cache}, " + f"causal={self.causal}), scale={self.scale}, use_kv_cache={self.use_kv_cache}, " f"share_past_present_buffer={self.share_past_present_buffer}, " f"provider={self.provider}, device={self.device}, enable_cuda_graph={self.enable_cuda_graph}, " f"dtype={self.dtype}, input_format={InputFormats.input_format_str(self.input_format)}, " - f"has_bias={self.has_bias}, mask_format={self.mask_format}" + f"has_bias={self.has_bias}, mask_format={self.mask_format}, " + f"has_attn_bias={self.has_attn_bias}, " + f"broadcast_attn_bias_dim_0={self.broadcast_attn_bias_dim_0}, " + f"broadcast_attn_bias_dim_1={self.broadcast_attn_bias_dim_1}, " ) def shape_dict(self, input_format=None): @@ -235,6 +257,14 @@ def shape_dict(self, input_format=None): else: assert self.mask_format == AttentionMaskFormat.Mask_None + if self.has_attn_bias: + shapes["attn_bias"] = ( + 1 if self.broadcast_attn_bias_dim_0 else self.batch_size, + 1 if self.broadcast_attn_bias_dim_1 else self.num_heads, + self.sequence_length, + self.total_sequence_length, + ) + return shapes def symbolic_shape_dict(self, input_format=None): @@ -288,12 +318,15 @@ def symbolic_shape_dict(self, input_format=None): shapes["bias"] = (3 * self.num_heads * self.head_size,) if self.mask_format == AttentionMaskFormat.Mask_1D_Key_SeqLen: - shapes["mask"] = (self.batch_size,) + shapes["mask"] = 
("batch_size",) elif self.mask_format == AttentionMaskFormat.Mask_2D_Key_PaddingMask: - shapes["mask"] = (self.batch_size, "total_sequence_length") + shapes["mask"] = ("batch_size", "total_sequence_length") else: assert self.mask_format == AttentionMaskFormat.Mask_None + if self.has_attn_bias: + shapes["attn_bias"] = ("batch_size_or_1", "num_heads_or_1", "sequence_length", "total_sequence_length") + return shapes def right_side_padding_masks(self): @@ -406,6 +439,19 @@ def random_inputs(self, seed: int = 123, no_bias_k_v: bool = False): if mask is not None: feeds = {**feeds, "mask": mask.to(dtype=torch.int32)} # mask is int32 (not bool) for MultiHeadAttention op. + if self.has_attn_bias: + attn_bias = torch.empty( + ( + 1 if self.broadcast_attn_bias_dim_0 else self.batch_size, + 1 if self.broadcast_attn_bias_dim_1 else self.num_heads, + self.sequence_length, + self.total_sequence_length, + ), + device=self.device, + dtype=dtype, + ).normal_(mean=0, std=0.1) + feeds["attn_bias"] = attn_bias + return feeds def get_input_output_names(self): @@ -425,6 +471,9 @@ def get_input_output_names(self): if self.mask_format != AttentionMaskFormat.Mask_None: inputs = [*inputs, "mask"] + if self.has_attn_bias: + inputs = [*inputs, "attn_bias"] + if self.has_past_input: inputs = [*inputs, "past_key", "past_value"] @@ -435,7 +484,7 @@ def get_input_output_names(self): def fill_optional_mha_inputs(input_names): - inputs = ["query", "key", "value", "bias", "mask", "relative_position_bias", "past_key", "past_value"] + inputs = ["query", "key", "value", "bias", "mask", "attn_bias", "past_key", "past_value"] # Remove optional inputs that are not in input_names with empty string inputs_with_optional = [input if input in input_names else "" for input in inputs] @@ -459,7 +508,7 @@ def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use "MultiHeadAttention_0", num_heads=config.num_heads, unidirectional=int(config.causal), - scale=config.softmax_scale, + 
scale=config.scale, mask_filter_value=float("-inf"), domain="com.microsoft", ), @@ -725,7 +774,12 @@ def run_tflops_test( # flash attention is available for sm >= 80 sm = get_compute_capability() if sm >= 80: - backends = [SdpaKernel.DEFAULT, SdpaKernel.FLASH_ATTENTION, SdpaKernel.EFFICIENT_ATTENTION] + backends = [ + SdpaKernel.DEFAULT, + SdpaKernel.FLASH_ATTENTION, + SdpaKernel.EFFICIENT_ATTENTION, + SdpaKernel.CUDNN_FLASH_ATTENTION, + ] else: backends = [SdpaKernel.DEFAULT, SdpaKernel.EFFICIENT_ATTENTION] else: @@ -786,7 +840,11 @@ def run_tflops_test( input_dict = config.random_inputs() # warm up session - _ = measure_latency(session, input_dict) + try: + _ = measure_latency(session, input_dict) + except Exception as e: + print(f"Failed to run {kernel=} for {config=}. Exception: {e}") + continue latency_list = [] for _ in range(repeats): @@ -1013,7 +1071,7 @@ def benchmark( head_size=head_size, causal=False, past_sequence_length=0, - kv_sequence_length=sequence_length if input_format == InputFormats.get_name_list()[-1] else None, + kv_sequence_length=sequence_length if input_format == "Q,K',V'" else None, max_cache_sequence_length=max_seq_len, provider="CUDAExecutionProvider", enable_cuda_graph=False, diff --git a/onnxruntime/test/python/transformers/test_mha.py b/onnxruntime/test/python/transformers/test_mha.py index 5948f8b1ccfc1..f6a837ddf829a 100644 --- a/onnxruntime/test/python/transformers/test_mha.py +++ b/onnxruntime/test/python/transformers/test_mha.py @@ -68,13 +68,26 @@ def get_bias_support(format: InputFormats): raise RuntimeError(f"Unknown format: {format}") +def get_atten_bias_support(): + atten_bias_options = [ + # (has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1) + (False, False, False), + (True, False, False), # [b, n, s_q, s_kv] + (True, True, False), # [1, n, s_q, s_kv] + # (True, False, True), # [b, 1, s_q, s_kv] + # (True, True, True), # [1, 1, s_q, s_kv] + ] + return atten_bias_options + + def attention_reference( 
head_size: int, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - mask: Optional[torch.Tensor] = None, scale: Optional[float] = None, + attn_bias: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, verbose: bool = False, ) -> torch.Tensor: """Reference implementation of SDPA @@ -84,8 +97,9 @@ def attention_reference( query (torch.Tensor): query in BNSH format key (torch.Tensor): key in BNSH format value (torch.Tensor): value in BNSH format - scale (Optional[float], optional): scale applied before softmax. Defaults to None. - mask (Optional[torch.Tensor], optional): attention mask. Defaults to None. + scale (Optional[float], optional): scale applied on QxK'. Defaults to None. + attn_bias : attention bias tensor added before softmax. Defaults to None. + masks : attention masks. Defaults to None. Returns: torch.Tensor: result of SDPA @@ -100,25 +114,30 @@ def attention_reference( if verbose: torch.set_printoptions(precision=6, linewidth=200, sci_mode=False) - print("query(SDPA)", query) - print("key(SDPA)", key) - print("value(SDPA)", value) + print("query(ref)", query) + print("key(ref)", key) + print("value(ref)", value) if mask is not None: print("mask", mask) # Apply multi-head attention. 
attn = torch.einsum("bhmd,bhnd->bhmn", query, key).float() * scale if verbose: - print("QK(SDPA)", attn) + print("QK(ref)", attn) + + if attn_bias is not None: + attn = attn + attn_bias + if verbose: + print("QK+AttnBias(ref)", attn) if mask is not None: attn = attn.masked_fill((1 - mask.int()).bool(), float("-inf")) if verbose: - print("masked QK(SDPA)", attn) + print("masked QK(ref)", attn) attn = attn.softmax(-1) if verbose: - print("Softmax(SDPA)", attn) + print("Softmax(ref)", attn) attn_output = torch.einsum("bhmn,bhnd->bhmd", attn.type_as(value), value) @@ -128,7 +147,7 @@ def attention_reference( torch.cuda.synchronize() if verbose: - print("result(SDPA)", result) + print("result(ref)", result) return result @@ -141,6 +160,7 @@ def mha_with_past_reference( k: torch.Tensor, v: torch.Tensor, scale: Optional[float] = None, + attn_bias: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, ): assert config.kv_sequence_length == config.sequence_length @@ -157,7 +177,7 @@ def mha_with_past_reference( present_k = torch.cat((past_k, k), dim=2) if past_k is not None else k present_v = torch.cat((past_v, v), dim=2) if past_v is not None else v - out = attention_reference(config.head_size, q, present_k, present_v, scale=scale, mask=mask) + out = attention_reference(config.head_size, q, present_k, present_v, scale=scale, attn_bias=attn_bias, mask=mask) return out, present_k, present_v @@ -185,6 +205,7 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): AttentionMaskFormat.Mask_1D_Key_SeqLen, AttentionMaskFormat.Mask_2D_Key_PaddingMask, ] + atten_bias_options = get_atten_bias_support() device, dtype, formats = get_provider_support_info(provider, False) if comprehensive: @@ -197,25 +218,33 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): for causal in [True, False]: for mask_format in mask_formats: for has_bias in get_bias_support(format): - config = MultiHeadAttentionConfig( - batch_size=batch_size, - 
sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - past_sequence_length=0, - kv_sequence_length=sequence_length, - max_cache_sequence_length=None, - provider=provider, - device=device, - dtype=dtype, - use_kv_cache=False, - share_past_present_buffer=False, - input_format=format, - has_bias=has_bias, - mask_format=mask_format, - ) - yield config + for ( + has_attn_bias, + broadcast_attn_bias_dim_0, + broadcast_attn_bias_dim_1, + ) in atten_bias_options: + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=0, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=False, + share_past_present_buffer=False, + input_format=format, + has_bias=has_bias, + mask_format=mask_format, + has_attn_bias=has_attn_bias, + broadcast_attn_bias_dim_0=broadcast_attn_bias_dim_0, + broadcast_attn_bias_dim_1=broadcast_attn_bias_dim_1, + ) + yield config else: test_cases = max(len(batch_sizes), len(sequence_lengths), len(heads), len(head_sizes)) for i in range(test_cases): @@ -224,6 +253,9 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): num_heads = heads[i % len(heads)] head_size = head_sizes[i % len(head_sizes)] mask_format = mask_formats[i % len(mask_formats)] + has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1 = atten_bias_options[ + i % len(atten_bias_options) + ] for causal in [True, False]: for format in formats: for has_bias in get_bias_support(format): @@ -244,6 +276,9 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): input_format=format, has_bias=has_bias, mask_format=mask_format, + has_attn_bias=has_attn_bias, + broadcast_attn_bias_dim_0=broadcast_attn_bias_dim_0, + broadcast_attn_bias_dim_1=broadcast_attn_bias_dim_1, ) yield config @@ -264,6 +299,8 @@ def 
kv_cache_test_cases(provider: str, comprehensive: bool): AttentionMaskFormat.Mask_2D_Key_PaddingMask, ] + atten_bias_options = get_atten_bias_support() + if comprehensive: sequence_lengths = [*sequence_lengths, 2048] # Large sequence length is slow and need a lot of memory for batch_size in batch_sizes: @@ -275,28 +312,36 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): for has_past_input in [True, False]: for mask_format in mask_formats: for has_bias in get_bias_support(format): - sequence_length = 1 if has_past_input else past_sequence_length - past_seq_len = past_sequence_length if has_past_input else 0 - config = MultiHeadAttentionConfig( - batch_size=batch_size, - sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - past_sequence_length=past_seq_len, - kv_sequence_length=sequence_length, - max_cache_sequence_length=None, - provider=provider, - device=device, - dtype=dtype, - use_kv_cache=True, - has_past_input=has_past_input, - share_past_present_buffer=False, - input_format=format, - has_bias=has_bias, - mask_format=mask_format, - ) - yield config + for ( + has_attn_bias, + broadcast_attn_bias_dim_0, + broadcast_attn_bias_dim_1, + ) in atten_bias_options: + sequence_length = 1 if has_past_input else past_sequence_length + past_seq_len = past_sequence_length if has_past_input else 0 + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=past_seq_len, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=True, + has_past_input=has_past_input, + share_past_present_buffer=False, + input_format=format, + has_bias=has_bias, + mask_format=mask_format, + has_attn_bias=has_attn_bias, + broadcast_attn_bias_dim_0=broadcast_attn_bias_dim_0, + broadcast_attn_bias_dim_1=broadcast_attn_bias_dim_1, + ) + 
yield config else: test_cases = max(len(batch_sizes), len(sequence_lengths), len(heads), len(head_sizes)) for i in range(test_cases): @@ -305,6 +350,9 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): num_heads = heads[i % len(heads)] head_size = head_sizes[i % len(head_sizes)] mask_format = mask_formats[i % len(mask_formats)] + has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1 = atten_bias_options[ + i % len(atten_bias_options) + ] for causal in [True, False]: for format in formats: for has_past_input in [True, False]: @@ -329,6 +377,9 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): input_format=format, has_bias=has_bias, mask_format=mask_format, + has_attn_bias=has_attn_bias, + broadcast_attn_bias_dim_0=broadcast_attn_bias_dim_0, + broadcast_attn_bias_dim_1=broadcast_attn_bias_dim_1, ) yield config @@ -470,6 +521,10 @@ def parity_check_mha( k = k + bias_k v = v + bias_v + attn_bias = None + if config.has_attn_bias: + attn_bias = ref_inputs["attn_bias"] + q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) @@ -480,11 +535,13 @@ def parity_check_mha( if config.use_kv_cache: past_k = ref_inputs.get("past_key", None) past_v = ref_inputs.get("past_value", None) - out_ref, k_cache, v_cache = mha_with_past_reference(config, past_k, past_v, q, k, v, mask=mask) + out_ref, k_cache, v_cache = mha_with_past_reference( + config, past_k, past_v, q, k, v, scale=config.scale, attn_bias=attn_bias, mask=mask + ) else: - out_ref = attention_reference(config.head_size, q, k, v, mask=mask) + out_ref = attention_reference(config.head_size, q, k, v, scale=config.scale, attn_bias=attn_bias, mask=mask) - # Fill zeros for the padded kens for comparison. + # Fill zeros for the padded tokens for comparison. 
if config.mask_index_q is not None: for i, m in enumerate(config.mask_index_q): out[i, m:, :, :] = 0 @@ -584,35 +641,69 @@ def check_parity_with_config(i: int): ) # Create reference inputs + old_format = config.input_format config.input_format = InputFormats.Q_K_V_BSNH_BSNH_BSNH ref_inputs = test_inputs[i]["ref_inputs"] if verbose: print(f"Thread {i} ref inputs: {ref_inputs}") - q = ( - ref_inputs["query"] - .reshape((config.batch_size, config.sequence_length, config.num_heads, config.head_size)) - .transpose(1, 2) - ) - k = ( - ref_inputs["key"] - .reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size)) - .transpose(1, 2) + + q = ref_inputs["query"].reshape((config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + k = ref_inputs["key"].reshape( + (config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size) ) - v = ( - ref_inputs["value"] - .reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size)) - .transpose(1, 2) + v = ref_inputs["value"].reshape( + (config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size) ) + if "bias" in ref_inputs: + bias = ref_inputs["bias"] + bias = bias.reshape((3, config.num_heads, config.head_size)) + bias_q = bias[0, :, :].reshape(1, 1, config.num_heads, config.head_size) + bias_k = bias[1, :, :].reshape(1, 1, config.num_heads, config.head_size) + bias_v = bias[2, :, :].reshape(1, 1, config.num_heads, config.head_size) + q = q + bias_q + k = k + bias_k + v = v + bias_v + + attn_bias = None + if config.has_attn_bias: + attn_bias = ref_inputs["attn_bias"] + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + mask = merge_padding_and_causal_masks(config) k_cache = None v_cache = None if config.use_kv_cache: - past_k = ref_inputs["past_key"] - past_v = ref_inputs["past_value"] - out_ref, k_cache, v_cache = mha_with_past_reference(config, past_k, past_v, q, k, v, mask=mask) + 
past_k = ref_inputs.get("past_key", None) + past_v = ref_inputs.get("past_value", None) + out_ref, k_cache, v_cache = mha_with_past_reference( + config, past_k, past_v, q, k, v, scale=config.scale, attn_bias=attn_bias, mask=mask + ) else: - out_ref = attention_reference(config.head_size, q, k, v, mask=mask) + out_ref = attention_reference(config.head_size, q, k, v, scale=config.scale, attn_bias=attn_bias, mask=mask) + + # Fill zeros for the padded tokens for comparison. + if config.mask_index_q is not None: + for i, m in enumerate(config.mask_index_q): + out[i, m:, :, :] = 0 + out_ref[i, m:, :, :] = 0 + + if config.mask_index_kv is not None and config.use_kv_cache: + assert k_cache is not None + assert v_cache is not None + present_key = ort_outputs[1] + present_value = ort_outputs[2] + for i, n in enumerate(config.mask_index_kv): + k_cache[i, :, n:, :] = 0 + present_key[i, :, n:, :] = 0 + v_cache[i, :, n:, :] = 0 + present_value[i, :, n:, :] = 0 + + # Restore the input format so that it shows up in the error message correctly. 
+ config.input_format = old_format try: numpy.testing.assert_allclose( diff --git a/onnxruntime/test/testdata/attention/attention_test_data.txt b/onnxruntime/test/testdata/attention/attention_test_data.txt index c52dd4ef1988b..7c60efea1f0f6 100644 --- a/onnxruntime/test/testdata/attention/attention_test_data.txt +++ b/onnxruntime/test/testdata/attention/attention_test_data.txt @@ -2812,26 +2812,26 @@ name:CrossAttentionDataWithPast.fp32_output_data 0.4291,0.5276,0.4818,0.4645,0.4768,0.4083,0.3377,0.4315, === -name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.query_data +name:SelfAttentionData_WithPast_WithAttnBias_ForT5.query_data 0.00403503,0.08716156,-0.0358175,-0.08171791, 0.48912194,-0.22679007,-0.09093101,-0.5939322, 0.00878838,0.03355761,-0.08080226,-0.06677517, 0.55038965,-0.2720567,-0.12977877,-0.634123, === -name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.key_data +name:SelfAttentionData_WithPast_WithAttnBias_ForT5.key_data 0.2808786,0.10041683,0.15880886,0.45283064, 0.39884242,0.12596075,0.4198916,-0.0651141, 0.31678027,0.11010794,0.21594375,0.4975329, 0.436772,0.20940652,0.44072092,-0.05601776, === -name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.value_data +name:SelfAttentionData_WithPast_WithAttnBias_ForT5.value_data 0.26421773,-0.16541699,-0.0599675,0.27200517, -0.1074627,-0.4493224,-0.03694462,0.17997989, 0.27960598,-0.16643806,-0.07019104,0.29006317, -0.11640988,-0.47876123,-0.01979145,0.11468418, === -name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.rel_pos_bias_data +name:SelfAttentionData_WithPast_WithAttnBias_ForT5.attention_bias_data 0.4781123,0.82420444,0.654424,0.3995186,0.5482078, 0.55570245,0.4216576,0.46001542,0.67183703,0.41973996, @@ -2839,7 +2839,7 @@ name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.rel_pos_bias_data 0.5460559,0.31994605,0.5470492,0.5433419,0.60349935, === -name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.past_key_data +name:SelfAttentionData_WithPast_WithAttnBias_ForT5.past_key_data 
0.34734827,0.5592256,0.5333037,0.5122027, 0.5940516,0.44744077,0.43128848,0.55360645, 0.57874715,0.29512063,0.2780432,0.4693917, @@ -2849,7 +2849,7 @@ name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.past_key_data 0.5352153,0.5157861,0.39744973,0.5441864, === -name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.past_value_data +name:SelfAttentionData_WithPast_WithAttnBias_ForT5.past_value_data 0.48998538,0.5493853,0.556647,0.7011929, 0.543909,0.5630743,0.5087797,0.3901024, 0.53116417,0.4086225,0.5320247,0.5145377, @@ -2858,12 +2858,12 @@ name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.past_value_data 0.52980417,0.5243695,0.6046111,0.53555113, 0.44936907,0.6010697,0.38031512,0.427301, === -name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.fp32_output_data +name:SelfAttentionData_WithPast_WithAttnBias_ForT5.fp32_output_data 0.4358,0.2708,0.3201,0.4347,0.1886,0.0845,0.2479,0.3289, 0.4157,0.2247,0.2826,0.4321,0.1874,0.1021,0.2427,0.3305, === -name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.present_key_data +name:SelfAttentionData_WithPast_WithAttnBias_ForT5.present_key_data 0.3473,0.5592,0.5333,0.5122, 0.5941,0.4474,0.4313,0.5536, 0.5787,0.2951,0.2780,0.4694, @@ -2877,7 +2877,7 @@ name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.present_key_data 0.4368,0.2094,0.4407,-0.0560, === -name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.present_value_data +name:SelfAttentionData_WithPast_WithAttnBias_ForT5.present_value_data 0.4900,0.5494,0.5566,0.7012, 0.5439,0.5631,0.5088,0.3901, 0.5312,0.4086,0.5320,0.5145, @@ -2891,7 +2891,7 @@ name:SelfAttentionData_WithPast_WithRelPosBias_ForT5.present_value_data -0.1164,-0.4788,-0.0198,0.1147, === -name:AttentionDataCutlassRelPosBias.query_data +name:AttentionDataCutlassAttnBias.query_data -0.029273793,0.079709493,0.064531095,0.24270254,-0.28326464,0.20984903,-0.10173888,0.18373983, 0.089472905,-0.0063416883,-0.049477674,0.36512995,-0.23620239,0.1464397,0.068258412,0.31627196, @@ -2909,7 +2909,7 @@ 
name:AttentionDataCutlassRelPosBias.query_data 0.002485469,0.029660821,-0.043821491,0.3892332,-0.26994205,0.14530671,0.12950704,0.36185294, === -name:AttentionDataCutlassRelPosBias.key_data +name:AttentionDataCutlassAttnBias.key_data -0.32538497,0.34121913,-0.18170178,-0.015152611,0.20429322,0.25979176,0.21269324,0.0025638193, -0.24246037,0.21112341,-0.36959589,-0.16091451,0.24183474,0.18856162,0.094487116,-0.3053959, @@ -2921,7 +2921,7 @@ name:AttentionDataCutlassRelPosBias.key_data -0.35736683,0.29276621,-0.4217523,-0.20031664,0.33148992,0.26928401,0.19360018,-0.39494509, -0.28043351,0.24279942,-0.29154932,-0.13657911,0.31932494,0.3500579,0.027172565,-0.19327414, === -name:AttentionDataCutlassRelPosBias.value_data +name:AttentionDataCutlassAttnBias.value_data 0.56916672,-0.2443777,0.47111356,-0.52134115,0.010381341,0.0696759,-0.071910433,-0.35201436, 0.70809275,-0.24479815,0.41633749,-0.34744334,-0.0044222325,0.25929695,-0.087832771,-0.281232, 0.90039468,-0.28931504,0.56394172,-0.43948689,-0.05856207,0.33713666,-0.10320446,-0.38833332, @@ -2931,7 +2931,7 @@ name:AttentionDataCutlassRelPosBias.value_data 0.90039468,-0.28931504,0.56394172,-0.43948689,-0.05856207,0.33713666,-0.10320446,-0.38833332, 0.76054728,-0.29080144,0.50414616,-0.42371163,-0.047198489,0.31959397,-0.22683662,-0.30321664, === -name:AttentionDataCutlassRelPosBias.bias_data +name:AttentionDataCutlassAttnBias.bias_data -0.38124341,0.02696526,-0.11914945,-0.43795273, 0.04772711,-0.03419551,-0.30606642,0.42656231, -0.25891554,0.13431972,0.22861153,0.06360734, @@ -2939,7 +2939,7 @@ name:AttentionDataCutlassRelPosBias.bias_data 0.27079183,0.42074734,-0.40314156,-0.43726659, -0.40546918,0.06927037,0.16979086,0.41458064, === -name:AttentionDataCutlassRelPosBias.rel_pos_bias_data +name:AttentionDataCutlassAttnBias.attention_bias_data -10.808288,-10.887209,7.8799553,-4.6565766,-1.6700006,-0.033962168,7.4929152,10.944146,8.640254,-18.862164,-3.1202927,-6.3049207,3.4508536,11.722519,3.3550568,-5.4888172, 
-2.0828252,-13.241742,2.9868939,1.4455698,-15.262972,-10.457437,-8.4519463,-4.4281874,10.212368,-0.28622282,12.087646,6.5218501,8.1785011,13.985523,-8.2068987,5.4260745, -10.808288,-10.887209,7.8799553,-4.6565766,-1.6700006,-0.033962168,7.4929152,10.944146,8.640254,-18.862164,-3.1202927,-6.3049207,3.4508536,11.722519,3.3550568,-5.4888172, @@ -2949,7 +2949,7 @@ name:AttentionDataCutlassRelPosBias.rel_pos_bias_data -10.808288,-10.887209,7.8799553,-4.6565766,-1.6700006,-0.033962168,7.4929152,10.944146,8.640254,-18.862164,-3.1202927,-6.3049207,3.4508536,11.722519,3.3550568,-5.4888172, -2.0828252,-13.241742,2.9868939,1.4455698,-15.262972,-10.457437,-8.4519463,-4.4281874,10.212368,-0.28622282,12.087646,6.5218501,8.1785011,13.985523,-8.2068987,5.4260745, === -name:AttentionDataCutlassRelPosBias.fp16_output_data +name:AttentionDataCutlassAttnBias.fp16_output_data 1.0419922,0.13000488,0.10528564,-0.86230469,-0.45336914,0.39013672,-0.048858643,0.10571289, 0.97265625,0.17590332,0.015625,-0.79248047,-0.40917969,0.31933594,0.082763672,0.12976074, 1.1455078,0.13134766,0.15014648,-0.87451172,-0.46142578,0.40161133,0.04309082,0.042663574, @@ -3095,61 +3095,61 @@ name:CrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias.present_value_data 1.20772719,-0.99407929,-0.15339416,0.54562038,1.29705775,-0.28651321,-0.90150839,-1.09473300, === -name:SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.query_data +name:SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.query_data 0.19646919,-0.21386067,-0.27314855,0.05131477,0.21946897,-0.07689354,0.4807642,0.18482974,-0.0190681,-0.10788248,-0.15682198,0.22904971,-0.06142776,-0.4403221,-0.10195574,0.23799541, === -name:SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.key_data +name:SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.key_data 
-0.31750827,-0.32454824,0.03155137,0.03182759,0.13440096,0.34943179,0.22445532,0.11102351,0.22244338,-0.17704109,-0.13821134,-0.27173677,-0.20628595,0.13097612,-0.40789506,-0.06629883, === -name:SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.value_data +name:SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.value_data -0.06913724,-0.0063149,-0.07416971,-0.18773878,-0.07364869,0.39338916,0.44416002,0.00183668,0.12395295,-0.3843816,-0.18271452,-0.08517379,0.36630916,-0.24954463,-0.01696574,0.48555979, === -name:SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.bias_data +name:SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.bias_data 0.01948512,0.11289453,-0.37937133,0.3263408,0.10306013,0.04506801,-0.15723617,-0.19587921,-0.08297779,0.18130077,0.37545684,0.01042234,0.16931378,0.08593655,0.1249035,0.17468905,0.34234244,-0.41680501,0.26368284,-0.25633363,-0.30577704,0.07245696,-0.40428748,0.38532683, === -name:SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.past_key_data +name:SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.past_key_data 0.12724897,0.22341636,-0.48387079,0.09443188,0.05678519,-0.34104036,-0.34692948,0.19552953,-0.18123357,0.1919703,0.05438325,-0.11104943,0.42513249,0.34167,-0.14260243,-0.45640854, === -name:SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.past_value_data +name:SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.past_value_data -0.19523193,-0.10181432,0.20495883,0.49535848,-0.14408513,0.26254781,0.09317692,0.1917018,-0.34887255,-0.10112371,-0.2591441,-0.15654399,0.01312815,0.16662455,-0.39409151,-0.36910505, === -name:SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.fp32_output_data +name:SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.fp32_output_data -0.00033577532,-0.23549549,0.19853255,0.10450245,-0.26995566,0.37128073,0.064667389,0.29624334,0.040147364,-0.43521237,-0.096833363,-0.24481347,0.037364807,-0.0091082826,-0.40797871,0.26487666, === 
-name:SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.present_key_data +name:SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.present_key_data 0.12724897,0.22341636,-0.4838708,0.094431877,-0.40048605,-0.14324747,0.4070082,0.042249933, 0.056785189,-0.34104037,-0.34692949,0.19552954,0.30371475,0.43536833,0.34935883,0.28571257, -0.18123357,0.1919703,0.054383252,-0.11104943,0.1394656,0.0042596906,0.2372455,-0.26131442, 0.42513248,0.34167001,-0.14260243,-0.45640853,-0.03697218,0.21691267,-0.28299156,0.10839023, === -name:SelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias.present_value_data +name:SelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias.present_value_data -0.19523193,-0.10181432,0.20495883,0.49535847,0.27320519,-0.4231199,0.18951313,-0.4440724, -0.14408512,0.26254782,0.093176924,0.1917018,-0.37942573,0.46584612,0.039872527,0.38716352, -0.34887254,-0.10112371,-0.2591441,-0.15654399,0.46629539,-0.80118656,0.08096832,-0.34150741, 0.01312815,0.16662455,-0.39409152,-0.36910504,0.060532123,-0.17708766,-0.42125323,0.87088662, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.query_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.query_data 1.29534733,2.14051294,1.09895217,1.39164531,-0.01471180,-1.40148544,-0.50825417,0.26134527, -0.70491123,0.63738143,2.13708138,0.05667466,-0.44220763,0.85254443,2.00844359,-1.23413038, -0.08030051,-1.25450790,-0.89664006,-0.69433510,0.20943037,1.41880298,1.42875051,0.79920006, 1.57896936,-1.13204634,-0.61002654,0.43365243,0.22888106,-0.38688308,-0.45924744,0.99473029, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.key_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.key_data 0.37680483,0.15317714,0.05767500,0.37780648,-2.27755547,0.89294612,-0.85582626,0.54963046, 1.67390800,-1.06330085,-2.99566054,0.68927419,1.66056263,-0.77022851,0.15417719,0.94860524, 
-1.84928346,-0.52135336,0.70491475,0.37400877,0.55338752,0.52915680,0.52876079,-0.55780333, -1.49814773,0.18675917,0.31246936,-1.32707596,0.42132780,-1.69121027,0.20342645,-0.34370381, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.value_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.value_data 0.60890561,-0.88021755,1.63002241,0.86171651,1.80559230,1.26110435,-0.97890180,-1.60215497, -0.79229754,1.07830989,-0.85298145,2.76264572,0.01659799,-1.49499071,0.85316724,-2.56763911, 0.53017867,1.31909978,-1.10940945,0.68858552,-1.07115889,-2.34016919,0.48310637,-0.05351824, -0.08850761,-0.56362265,0.05224326,-2.47377181,0.44249821,-0.10389519,-0.46113095,2.81619215, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.bias_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.bias_data -0.38124341,0.02696526,-0.11914945,-0.43795273,-0.34948170,-0.19608477,0.19725692,0.39987487, 0.04772711,-0.03419551,-0.30606642,0.42656231,-0.23178342,-0.13692456,-0.04889601,0.48739988, 0.27079183,0.42074734,-0.40314156,-0.43726659,0.27376485,-0.38174152,-0.43700469,0.38040614, @@ -3158,28 +3158,28 @@ name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.bias_dat 0.34785229,0.00531715,-0.35168743,-0.11641458,0.39196932,0.44535065,0.43545735,0.15593112, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.past_key_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.past_key_data -1.00657940,-0.46509427,-1.65118766,-0.17705369,1.71204090,0.53921354,-1.67056096,0.42517155, -2.00129080,1.26244307,0.28864837,1.38792157,-0.59647840,-1.18904924,0.58950418,-2.26774645, 1.88496518,0.59231639,0.33360308,-1.23532701,0.10543400,-1.77481365,-0.79397631,-0.22495472, -0.26800078,-0.20456636,1.43141091,1.55566478,-0.22702518,1.75312757,-1.29037595,-0.95538902, === 
-name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.past_value_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.past_value_data 3.18056512,0.13370860,-2.20253444,2.30826044,0.86762893,-1.91499686,2.18277764,0.53384149, -0.43230706,0.49148068,-0.29957789,-3.56583714,-1.46747136,-0.40299624,1.78018796,2.84104395, -0.68692255,1.25688624,-0.42734757,-1.03185725,0.47858545,1.18466282,-1.06095874,-0.63918531, 1.41408277,0.74389833,0.89590931,1.06388271,1.29734015,0.42640167,-0.99740052,-2.79366398, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.fp32_output_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.fp32_output_data 0.72723210,-0.54989153,1.22711349,1.26993895,1.78235006,1.12648177,-0.42493403,-1.27518260, -0.43240935,0.49647018,-0.30720428,-3.51349354,-1.45166361,-0.40844491,1.77604592,2.79678369, 0.25752395,1.53741217,-1.08321750,0.69643497,-0.78710371,-1.68901348,0.51954043,-0.00401744, 1.11207914,0.40332735,0.58328331,0.10821819,1.17628312,0.40418532,-0.74326056,-1.28571272, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.present_key_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.present_key_data -1.00657940,-0.46509427,-1.65118766,-0.17705369,1.71204090,0.53921354,-1.67056096,0.42517155, 0.64759666,0.57392448,-0.34546655,-0.05946010,-2.00379062,0.51120460,-1.29283094,0.93003660, -2.00129080,1.26244307,0.28864837,1.38792157,-0.59647840,-1.18904924,0.58950418,-2.26774645, @@ -3190,7 +3190,7 @@ name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.present_ -1.90361691,0.25602955,0.48226023,-0.91249532,0.49253359,-1.77176893,0.32437757,-0.62359041, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.present_value_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias.present_value_data 
3.18056512,0.13370860,-2.20253444,2.30826044,0.86762893,-1.91499686,2.18277764,0.53384149, 0.50323355,-0.61230683,1.54025340,1.17513633,1.86586761,1.40418029,-0.66302794,-1.44035339, -0.43230706,0.49148068,-0.29957789,-3.56583714,-1.46747136,-0.40299624,1.78018796,2.84104395, @@ -3201,28 +3201,28 @@ name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias.present_ 0.25934470,-0.55830550,-0.29944417,-2.59018636,0.83446753,0.34145546,-0.02567360,2.97212315, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias.past_key_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias.past_key_data -1.27737117,-0.88584161,-1.24804604,0.26021290,1.43827605,0.92095506,-1.23355627,0.04476542, -1.59582162,1.19317269,0.11885749,0.97334087,-0.66768420,-1.10849059,0.46855307,-1.98785996, 1.61417341,0.17156902,0.73674464,-0.79806042,-0.16833085,-1.39307213,-0.35697165,-0.60536087, 0.13746840,-0.27383673,1.26162004,1.14108407,-0.29823098,1.83368623,-1.41132712,-0.67550242, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias.past_value_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias.past_value_data 3.28623724,-0.13420212,-2.11276555,1.99484074,0.80735362,-2.05807281,1.86690378,0.37204000, -0.78015935,0.48616353,0.05210955,-3.44942260,-1.85944068,-0.84834689,1.34473062,2.68511271, -0.58125055,0.98897558,-0.33757859,-1.34527707,0.41831014,1.04158688,-1.37683260,-0.80098683, 1.06623054,0.73858118,1.24759674,1.18029726,0.90537083,-0.01894896,-1.43285787,-2.94959521, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias.fp32_output_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias.fp32_output_data 0.89556247,-0.80034304,1.22928894,0.98303795,1.69871271,0.90572613,-0.67420667,-1.39078152, -0.78021139,0.48869953,0.04823331,-3.42281842,-1.85140634,-0.85111630,1.34262550,2.66261697, 
0.34449580,1.26394701,-0.98046219,0.34879467,-0.82231814,-1.77519011,0.17237240,-0.17839541, 0.72679031,0.35579273,0.89621741,0.10616791,0.76930743,-0.04391927,-1.14721453,-1.25471735, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias.present_key_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias.present_key_data -1.27737117,-0.88584161,-1.24804604,0.26021290,1.43827605,0.92095506,-1.23355627,0.04476542, 0.37680483,0.15317714,0.05767500,0.37780648,-2.27755547,0.89294612,-0.85582626,0.54963046, -1.59582162,1.19317269,0.11885749,0.97334087,-0.66768420,-1.10849059,0.46855307,-1.98785996, @@ -3233,7 +3233,7 @@ name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias.p -1.49814773,0.18675917,0.31246936,-1.32707596,0.42132780,-1.69121027,0.20342645,-0.34370381, === -name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoRelPosBias_NoBias.present_value_data +name:SelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias.present_value_data 3.28623724,-0.13420212,-2.11276555,1.99484074,0.80735362,-2.05807281,1.86690378,0.37204000, 0.60890561,-0.88021755,1.63002241,0.86171651,1.80559230,1.26110435,-0.97890180,-1.60215497, -0.78015935,0.48616353,0.05210955,-3.44942260,-1.85944068,-0.84834689,1.34473062,2.68511271, diff --git a/onnxruntime/test/testdata/attention/packed_multihead_attention_test_data.txt b/onnxruntime/test/testdata/attention/packed_multihead_attention_test_data.txt index 2e91cf46ce5f1..5bb83e7daa1ca 100644 --- a/onnxruntime/test/testdata/attention/packed_multihead_attention_test_data.txt +++ b/onnxruntime/test/testdata/attention/packed_multihead_attention_test_data.txt @@ -1,4 +1,4 @@ -name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.query_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias.query_data -0.35420692,1.31206024,-2.80201197,2.42258096,-0.86031514,-1.44535458,-0.10832444,-2.00132895, 
1.62475216,0.10978927,1.84596729,0.48908550,1.44369888,0.87542874,-1.16434252,0.52133209, 1.54848897,-2.21174526,-0.28574878,0.70815033,1.18327498,3.14097571,-0.25795099,1.89341247, @@ -14,7 +14,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.query_data -0.93303132,-0.84753871,-4.32799959,-1.94716609,-1.16980326,1.62631667,2.41053247,3.78186774, 0.26432252,-0.40396988,2.04414082,0.65150046,0.47777444,-2.57569051,0.99004912,2.47947693, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.key_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias.key_data -0.04407793,1.29459429,1.05810797,1.92067695,-0.65047157,0.99029726,-1.69796586,1.15320420, -1.66444266,1.78305888,1.20582056,1.69975281,0.34572244,-0.60833001,2.59864879,-1.05330181, -1.16554165,-0.03781542,-1.13475525,0.71595150,-0.91169560,1.26686060,1.60492957,-0.53510487, @@ -30,7 +30,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.key_data 2.42824388,1.56369960,1.69934130,-0.42460468,-2.25951004,-1.18074155,3.51091242,-0.30183151, -1.83517075,-0.56233191,2.35561657,-3.63751698,-3.20001125,-1.66120780,3.23455381,-1.86251283, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.value_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias.value_data -0.89167893,0.02633595,-0.84866279,1.43489110,-2.91941142,-0.20650116,1.85965109,0.45669034, 0.07678832,0.04492294,0.67326981,0.97103029,1.53470886,-1.10242307,0.86584085,-0.34770033, -1.24311507,-1.80293822,-1.01317739,-0.71518499,0.77814674,-0.59236068,-2.00310278,3.13277125, @@ -46,7 +46,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.value_data 1.14034331,-1.41539204,0.13379651,3.47018123,1.53924727,1.50004411,2.87318921,1.62624204, 0.64942807,-4.54302311,-1.50294220,-1.75212634,0.27900690,-3.05124855,3.30960631,-0.07991691, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.qkv_data 
+name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias.qkv_data -0.35420692,1.31206024,-2.80201197,2.42258096,-0.86031514,-1.44535458,-0.10832444,-2.00132895, 1.62475216,0.10978927,1.84596729,0.48908550,1.44369888,0.87542874,-1.16434252,0.52133209, 1.54848897,-2.21174526,-0.28574878,0.70815033,1.18327498,3.14097571,-0.25795099,1.89341247, @@ -86,7 +86,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.qkv_data 1.14034331,-1.41539204,0.13379651,3.47018123,1.53924727,1.50004411,2.87318921,1.62624204, 0.64942807,-4.54302311,-1.50294220,-1.75212634,0.27900690,-3.05124855,3.30960631,-0.07991691, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.fp16_output_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoAttnBias.fp16_output_data -0.89160156,0.02633667,-0.84863281,1.4345703,-2.9199219,-0.20654297,1.859375,0.45678711, 0.076782227,0.044921875,0.67333984,0.97119141,1.5351562,-1.1025391,0.86572266,-0.34765625, -1.2431641,-1.8027344,-1.0126953,-0.71533203,0.77832031,-0.59228516,-2.0039062,3.1328125, @@ -102,7 +102,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias.fp16_output_dat 1.08301103,-1.26178384,0.16304730,3.16210985,1.36142719,1.32916999,2.69524455,1.45106804, 0.67150640,-4.31703520,-1.34025633,-1.59496248,0.37821823,-2.85797405,3.11096096,-0.17414713f === -name:PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.query_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias.query_data -1.83615911,0.08698978,0.05601556,-1.14510250,-2.30377889,-0.39893439,0.73342341,-0.09851928, -0.45148617,-0.16055907,-1.48271382,-0.07961921,-0.65701288,-0.25778309,-0.72851723,0.86755788, @@ -111,7 +111,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.query_data -0.20033565,-1.51847255,0.95205748,0.54009491,1.19315910,0.81655478,0.87503016,0.09732430, -0.53218621,-0.11167067,0.67364228,-0.59705222,-0.24946509,0.20462716,-0.56092483,-0.65660709, === 
-name:PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.key_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias.key_data 0.86949563,-0.10868365,-0.37917313,-1.23103046,0.25640076,-1.50652349,0.71594471,0.49057019, -1.41292810,-0.19686662,1.25451696,-1.59823179,-1.16262913,0.84965342,0.61178929,-1.26162946, @@ -120,7 +120,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.key_data 0.47295785,0.65468878,-1.44158995,-0.05122741,-0.34755200,0.66963655,0.72664660,1.59155345, -1.13806772,0.70947856,-0.65793264,-0.50718778,-1.20698619,0.32613355,0.61786091,-0.34040576, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.value_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias.value_data -1.19203627,0.38844836,0.68121153,0.21624038,-1.77549291,0.18574584,0.90408206,-0.22868094, -0.95558548,1.38712502,0.81038797,0.14359820,0.15352470,0.00469783,0.03943123,0.53865469, @@ -129,7 +129,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.value_data -0.15860432,-0.24945745,0.67483073,0.18782829,-0.56960964,1.16764832,-0.72244978,0.55027384, -0.37327161,1.19222152,-0.23447749,0.06147140,0.32951999,1.06427121,2.26385999,0.23828916, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.qkv_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias.qkv_data -1.83615911,0.08698978,0.05601556,-1.14510250,-2.30377889,-0.39893439,0.73342341,-0.09851928, 0.86949563,-0.10868365,-0.37917313,-1.23103046,0.25640076,-1.50652349,0.71594471,0.49057019, -1.19203627,0.38844836,0.68121153,0.21624038,-1.77549291,0.18574584,0.90408206,-0.22868094, @@ -154,14 +154,14 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.qkv_data -1.13806772,0.70947856,-0.65793264,-0.50718778,-1.20698619,0.32613355,0.61786091,-0.34040576, -0.37327161,1.19222152,-0.23447749,0.06147140,0.32951999,1.06427121,2.26385999,0.23828916, === 
-name:PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.rel_pos_bias_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias.attention_bias_data 0.4781123,0.82420444,0.654424,0.3995186, 0.5482078,0.55570245,0.4216576,0.46001542, 0.4781123,0.82420444,0.654424,0.3995186, 0.5482078,0.55570245,0.4216576,0.46001542, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.fp16_output_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize8_AttnBias.fp16_output_data -1.1923828,0.38842773,0.68115234,0.21618652,-1.7753906,0.18579102,0.90429688,-0.2286377, -0.95556641,1.3867188,0.81054688,0.14355469,0.15356445,0.004699707,0.039428711,0.53857422, @@ -172,7 +172,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize8_RelPosBias.fp16_output_data -0.17407227,0.57763672,-0.3046875,0.51025391,-0.097045898,0.98974609,1.0234375,0.47949219, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.query_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias.query_data -1.83615911,0.08698978,0.05601556,-1.14510250,-2.30377889,-0.39893439,0.73342341,-0.09851928, -0.45148617,-0.16055907,-1.48271382,-0.07961921,-0.65701288,-0.25778309,-0.72851723,0.86755788, @@ -194,7 +194,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.query_dat -0.16418101,0.30182290,0.76461935,0.89762378,-0.70261180,1.31333566,0.86440170,-0.55341989, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.key_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias.key_data 0.86949563,-0.10868365,-0.37917313,-1.23103046,0.25640076,-1.50652349,0.71594471,0.49057019, -1.41292810,-0.19686662,1.25451696,-1.59823179,-1.16262913,0.84965342,0.61178929,-1.26162946, @@ -216,7 +216,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.key_data -1.74471772,0.38858974,0.77225429,-0.47355813,0.59074765,-0.50501788,-1.72981727,-1.25862873, === 
-name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.value_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias.value_data -1.19203627,0.38844836,0.68121153,0.21624038,-1.77549291,0.18574584,0.90408206,-0.22868094, -0.95558548,1.38712502,0.81038797,0.14359820,0.15352470,0.00469783,0.03943123,0.53865469, @@ -238,7 +238,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.value_dat -1.04708695,1.04990900,0.61408597,0.48327276,0.61544299,-0.57864964,-0.80768973,0.39645281, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.qkv_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias.qkv_data -1.83615911,0.08698978,0.05601556,-1.14510250,-2.30377889,-0.39893439,0.73342341,-0.09851928, 0.86949563,-0.10868365,-0.37917313,-1.23103046,0.25640076,-1.50652349,0.71594471,0.49057019, -1.19203627,0.38844836,0.68121153,0.21624038,-1.77549291,0.18574584,0.90408206,-0.22868094, @@ -312,7 +312,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.qkv_data -1.04708695,1.04990900,0.61408597,0.48327276,0.61544299,-0.57864964,-0.80768973,0.39645281, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.rel_pos_bias_data +name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias.attention_bias_data 0.09734076,-0.01747033,0.008497253,-0.03361112,-0.028750911,-0.017142132,-0.11563814,0.10432467, 0.057628587,0.030893803,-0.096876964,0.11924802,-0.009177148,0.05799888,-0.030559167,0.034150958, 0.07427484,0.028848544,-0.031371966,0.07186346,-0.093020484,-0.066411436,0.06858949,0.07350862, @@ -332,7 +332,7 @@ name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.rel_pos_b 0.013226762,-0.07403794,0.06855075,-0.06551643,-0.084110215,0.11237715,0.07026932,-0.014076158, === -name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastRelPosBias.fp16_output_data 
+name:PackedMultiHeadAttentionData_Batch2_HeadSize8_BroadcastAttnBias.fp16_output_data -1.1923828,0.38842773,0.68115234,0.21618652,-1.7753906,0.18579102,0.90429688,-0.2286377, -0.95556641,1.3867188,0.81054688,0.14355469,0.15356445,0.004699707,0.039428711,0.53857422, From c76f2940e4e077cbf2af279947d1c3752c2009f5 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 12 Aug 2024 19:21:59 +0000 Subject: [PATCH 02/13] broadcast attn bias in decoder masked mha --- .../decoder_masked_multihead_attention.cc | 14 +++++----- .../bert/decoder_masked_self_attention.cc | 8 +++--- ...decoder_masked_multihead_attention_impl.cu | 26 +++++++++++++------ .../decoder_masked_multihead_attention_impl.h | 4 +-- .../cuda/bert/multihead_attention.cc | 2 +- 5 files changed, 32 insertions(+), 22 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc index 037a4fdf3d9a0..350c4718c437e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc @@ -60,7 +60,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* const Tensor* key = context->Input(1); const Tensor* value = context->Input(2); const Tensor* mask_index = context->Input(3); - const Tensor* relative_position_bias = context->Input(4); + const Tensor* attention_bias = context->Input(4); const Tensor* past_key = context->Input(kPastInputIndex); const Tensor* past_value = context->Input(kPastInputIndex + 1); const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); @@ -80,7 +80,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* value, bias, mask_index, - relative_position_bias, + attention_bias, past_key, past_value, past_seq_len, @@ -141,16 +141,16 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* // Update the q buffers 
parameters.q = const_cast(query->Data()); - // Update the relative position bias for self attention - if (relative_position_bias != nullptr) { - parameters.relative_attention_bias = const_cast(relative_position_bias->Data()); + // Update the attention bias for self attention + if (attention_bias != nullptr) { + parameters.attention_bias = const_cast(attention_bias->Data()); } // Decoder cross-attention if (past_key == nullptr && present_key == nullptr) { - if (relative_position_bias != nullptr) { + if (attention_bias != nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "DecoderMaskedMultiHeadAttention does not support relative position bias for cross-attention"); + "DecoderMaskedMultiHeadAttention does not support attention bias for cross-attention"); } parameters.is_cross_attention = true; diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc index 07a6fbd60e171..e7d117686a538 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc @@ -45,7 +45,7 @@ Status DecoderMaskedSelfAttention::ComputeInternal(OpKernelContext* cont const Tensor* bias = context->Input(2); const Tensor* mask_index = context->Input(3); const Tensor* past = context->Input(kPastInputIndex); - const Tensor* relative_position_bias = context->Input(5); + const Tensor* attention_bias = context->Input(5); const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); const Tensor* beam_width = context->Input(kBeamWidthInputIndex); const Tensor* cache_indir = context->Input(kCacheIndirectionInputIndex); @@ -61,7 +61,7 @@ Status DecoderMaskedSelfAttention::ComputeInternal(OpKernelContext* cont bias->Shape(), mask_index, past, - relative_position_bias, + attention_bias, &parameters, device_prop.maxThreadsPerBlock, past_seq_len)); @@ -85,8 +85,8 @@ Status
DecoderMaskedSelfAttention::ComputeInternal(OpKernelContext* cont } // TODO(hasesh): If there is a need, we will support this later - if (relative_position_bias != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "DecoderMaskedSelfAttention does not support relative position bias currently"); + if (attention_bias != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "DecoderMaskedSelfAttention does not support attention bias currently"); } // TODO(hasesh): Support more mask types. Currently, it only supports the HuggingFace GreedySearch/BeamSearch pattern. diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index 2f8d277cb7342..a0115ef9f2304 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -154,6 +154,18 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio // The offset in the Q and K buffer also accounts for the batch. int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE; + // The offset of attention bias for current head. + int64_t attn_bias_offset = 0; + if (params.attention_bias != nullptr && params.attention_bias_dims.size() == 4) { + // Support broadcasting the first and second dimensions of attention bias. + if (params.attention_bias_dims[0] > 1) { + attn_bias_offset = static_cast(bbi) * params.num_heads * params.sequence_length * params.total_sequence_length; + } + if (params.attention_bias_dims[1] > 1) { + attn_bias_offset += static_cast(hi) * params.sequence_length * params.total_sequence_length; + } + } + // Trigger the loads from the Q and K buffers. 
Qk_vec_k q; zero(q); @@ -286,9 +298,8 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio if (tidx == 0) { // Normalize qk. qk *= inv_sqrt_dh; - if (params.relative_attention_bias != nullptr) { - qk = add_vec(qk, - reinterpret_cast(params.relative_attention_bias)[hi * params.sequence_length * params.total_sequence_length + tlength]); + if (params.attention_bias != nullptr) { + qk = add_vec(qk, reinterpret_cast(params.attention_bias)[attn_bias_offset + tlength]); } qk_max = qk; qk_smem[tlength] = qk; @@ -386,9 +397,8 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio // Store the product to shared memory. There's one qk value per timestep. Update the max. if (ti < tlength && tidx % THREADS_PER_KEY == 0) { - if (params.relative_attention_bias != nullptr) { - qk = add_vec(qk, - reinterpret_cast(params.relative_attention_bias)[hi * params.sequence_length * params.total_sequence_length + ti]); + if (params.attention_bias != nullptr) { + qk = add_vec(qk, reinterpret_cast(params.attention_bias)[attn_bias_offset + ti]); } qk_max = fmaxf(qk_max, qk); qk_smem[ti] = qk; @@ -479,9 +489,9 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio #pragma unroll for (int k_unroll = 0; k_unroll < K_CACHE_DATA_LOAD_UNROLL; ++k_unroll) { if (time_bounds_cond[k_unroll] && (tidx % THREADS_PER_KEY == 0)) { - if (params.relative_attention_bias != nullptr) { + if (params.attention_bias != nullptr) { qk[k_unroll] = add_vec(qk[k_unroll], - reinterpret_cast(params.relative_attention_bias)[hi * params.sequence_length * params.total_sequence_length + time_step[k_unroll]]); + reinterpret_cast(params.attention_bias)[attn_bias_offset + time_step[k_unroll]]); } qk_max = fmaxf(qk_max, qk[k_unroll]); qk_smem[time_step[k_unroll]] = qk[k_unroll]; diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h 
b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h index 1a17757d1ec2d..efad33855328f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h @@ -37,7 +37,7 @@ struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters { void* v = nullptr; void* v_bias = nullptr; - void* relative_attention_bias = nullptr; + void* attention_bias = nullptr; void* k_cache = nullptr; void* v_cache = nullptr; @@ -68,4 +68,4 @@ void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParams& params, cud } // namespace cuda } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 58e41345431e1..da20521fb42d1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -249,7 +249,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { fused_cross_attention_kernel == nullptr && !disable_memory_efficient_attention_ && is_long_sequence && - // Check whether the relative position bias alignment is good for memory efficient attention. + // Check whether the attention bias alignment is good for memory efficient attention. 
(attention_bias == nullptr || parameters.sequence_length % (4 * sizeof(T)) == 0) && (nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) && has_memory_efficient_attention(sm, std::is_same::value, From a8cebba1125b3ac792be81a6c6456ea2a15350f5 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 13 Aug 2024 06:44:13 +0000 Subject: [PATCH 03/13] Add MHA tests --- onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h | 5 +++-- .../contrib_ops/cpu/bert/multihead_attention_helper.h | 2 +- onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu | 7 ------- .../cuda/bert/cutlass_fmha/fmha_launch_template.h | 2 +- onnxruntime/test/python/transformers/test_mha.py | 4 ++-- 5 files changed, 7 insertions(+), 13 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index a49cf60655a49..d9516b6edc2c4 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -222,8 +222,9 @@ class AttentionCPUBase : public AttentionBase { // B: K' (B x N x) T x H (B x N x) H x T H x T // C: attention_probs (B x N x) S x T (B x N x) S x T S x T math::Gemm(CblasNoTrans, CblasTrans, sequence_length, total_sequence_length, head_size, alpha, - Q + q_input_chunk_length * i, k, mask_data != nullptr ? 1.0f : 0.0f, output, - nullptr); + Q + q_input_chunk_length * i, k, + (mask_data != nullptr || attn_bias_data != nullptr) ? 
1.0f : 0.0f, + output, nullptr); } }); } diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index 85f223f8ec7a4..8a644147dbcca 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -181,7 +181,7 @@ Status CheckPast(const T* past_key, const T* past_value, const T* past_seq_len, inline Status CheckAttentionBias( const gsl::span& attention_bias_dims, - int batch_size, int num_heads, int sequence_length, int total_sequence_length) { + int64_t batch_size, int64_t num_heads, int64_t sequence_length, int64_t total_sequence_length) { if (attention_bias_dims.size() != 4) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'attention_bias' is expected to have 4 dimensions, got ", diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu index 494e708a85485..b11a6aa887039 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu @@ -50,7 +50,6 @@ __device__ inline void Softmax(const int total_sequence_length, // grid size is (sequence_length, num_heads, batch_size); total_sequence_length is partitioned to blocks by TPB. const int sequence_length = gridDim.x; const int num_heads = gridDim.y; - const int batch_size = gridDim.z; const int s = blockIdx.x; const int n = blockIdx.y; const int b = blockIdx.z; @@ -138,7 +137,6 @@ __device__ inline void SoftmaxSmall(const int total_sequence_length, // grid size is (sequence_length, num_heads, batch_size); total_sequence_length is within one block size TPB. 
const int sequence_length = gridDim.x; const int num_heads = gridDim.y; - const int batch_size = gridDim.z; const int s = blockIdx.x; const int n = blockIdx.y; const int b = blockIdx.z; @@ -226,7 +224,6 @@ __global__ void SoftmaxLargeKernel(const int total_sequence_length, // grid size is (sequence_length, num_heads, batch_size); total_sequence_length is partitioned by TPB. const int sequence_length = gridDim.x; const int num_heads = gridDim.y; - const int batch_size = gridDim.z; const int s = blockIdx.x; const int n = blockIdx.y; const int b = blockIdx.z; @@ -319,12 +316,10 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int total_sequence_length, __shared__ float max_block; float max_thread_data = -CUDART_INF_F; - const int size_per_batch = gridDim.x * total_sequence_length; // grid size is (sequence_length, num_heads, batch_size); total_sequence_length is partitioned by TPB. const int sequence_length = gridDim.x; const int num_heads = gridDim.y; - const int batch_size = gridDim.z; const int s = blockIdx.x; const int n = blockIdx.y; const int b = blockIdx.z; @@ -443,7 +438,6 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int total_sequence_length, // grid size is (sequence_length, num_heads, batch_size); total_sequence_length is within one block size TPB. const int sequence_length = gridDim.x; const int num_heads = gridDim.y; - const int batch_size = gridDim.z; const int s = blockIdx.x; const int n = blockIdx.y; const int b = blockIdx.z; @@ -647,7 +641,6 @@ __device__ inline void SoftmaxSmallPacked(const int total_sequence_length, // grid size is (sequence_length, num_heads, batch_size); total_sequence_length is within TPB. 
const int sequence_length = gridDim.x; const int num_heads = gridDim.y; - const int batch_size = gridDim.z; const int s = blockIdx.x; const int n = blockIdx.y; const int b = blockIdx.z; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index 5029abe7e11e6..39d70bf1ea9bc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -216,7 +216,7 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { (bias_dims[1] == 1 || bias_dims[1] == params.num_heads) && bias_dims[2] == params.sequence_length && bias_dims[3] == params.kv_sequence_length); - p.bias_strideH = p.num_queries * p.num_keys; + p.bias_strideH = (bias_dims[1] == 1) ? 0 : p.num_queries * p.num_keys; p.bias_strideM = p.num_keys; p.bias_strideB = (bias_dims[0] == 1) ? 0 : (bias_dims[1] * p.num_queries * p.num_keys); } else { diff --git a/onnxruntime/test/python/transformers/test_mha.py b/onnxruntime/test/python/transformers/test_mha.py index f6a837ddf829a..5ebc02c84acb2 100644 --- a/onnxruntime/test/python/transformers/test_mha.py +++ b/onnxruntime/test/python/transformers/test_mha.py @@ -74,8 +74,8 @@ def get_atten_bias_support(): (False, False, False), (True, False, False), # [b, n, s_q, s_kv] (True, True, False), # [1, n, s_q, s_kv] - # (True, False, True), # [b, 1, s_q, s_kv] - # (True, True, True), # [1, 1, s_q, s_kv] + (True, False, True), # [b, 1, s_q, s_kv] + (True, True, True), # [1, 1, s_q, s_kv] ] return atten_bias_options From c728b0beb344e3a1e84b7b04ce0af1b60b48407a Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 13 Aug 2024 07:14:30 +0000 Subject: [PATCH 04/13] rename relative_position_bias to attention_bias --- js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 2 +- .../jsep/webgpu/ops/multihead-attention.ts | 2 +- onnxruntime/contrib_ops/cpu/bert/attention.cc | 6 ++-- 
.../cpu/quantization/attention_quant.cc | 2 +- .../contrib_ops/cuda/bert/packed_attention.h | 2 +- .../cuda/bert/packed_multihead_attention.h | 2 +- .../quantization/attention_quantization.cc | 2 +- .../qordered_ops/qordered_attention.cc | 2 +- .../qordered_attention_input_enum.h | 2 +- .../core/graph/contrib_ops/bert_defs.cc | 25 ++++++++--------- .../graph/contrib_ops/quantization_defs.cc | 6 ++-- .../core/providers/cpu/cpu_provider_shared.cc | 4 +-- .../core/providers/cpu/cpu_provider_shared.h | 2 +- .../provider_bridge_provider.cc | 4 +-- .../tools/transformers/convert_generation.py | 2 +- .../transformers/fusion_rotary_attention.py | 2 +- .../test/contrib_ops/attention_op_test.cc | 28 +++++++++---------- .../multihead_attention_op_test.cc | 8 +++--- .../contrib_ops/packed_attention_op_test.cc | 24 ++++++++-------- .../contrib_ops/qordered_attention_test.cc | 2 +- .../test_parity_neox_attention.py | 2 +- .../python/transformers/test_parity_t5_mha.py | 14 ++++------ 22 files changed, 71 insertions(+), 74 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index 31a8823447f0d..74bfb556e3602 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -218,7 +218,7 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte attentionBias.dims[1] !== attributes.numHeads || attentionBias.dims[2] !== sequenceLength || attentionBias.dims[3] !== totalSequenceLength) { - throw new Error('Input "attention_bias" shape shall be (batch_size, num_heads, sequence_length, total_sequence_length)'); + throw new Error('Expect "attention_bias" shape (batch_size, num_heads, sequence_length, total_sequence_length)'); } } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts index c83bf1481e109..d9c0a62c6479e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts 
+++ b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts @@ -222,7 +222,7 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr attentionBias.dims[1] !== attributes.numHeads || attentionBias.dims[2] !== sequenceLength || attentionBias.dims[3] !== totalSequenceLength) { - throw new Error('Input "attention_bias" shape shall be (batch_size, num_heads, sequence_length, total_sequence_length)'); + throw new Error('Expect "attention_bias" shape (batch_size, num_heads, sequence_length, total_sequence_length)'); } } diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index 768676259aa14..ad14fb8258656 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -198,7 +198,7 @@ Status Attention::Compute(OpKernelContext* context) const { const Tensor* mask_index = context->Input(3); const Tensor* past = context->Input(4); - const Tensor* relative_position_bias = context->Input(5); + const Tensor* attention_bias = context->Input(5); const TensorShape& weights_shape = (weights ? 
weights->Shape() : weight_shape_); @@ -208,7 +208,7 @@ Status Attention::Compute(OpKernelContext* context) const { bias->Shape(), mask_index, past, - relative_position_bias, + attention_bias, &parameters)); if (parameters.do_rotary) { @@ -338,7 +338,7 @@ Status Attention::Compute(OpKernelContext* context) const { output, nullptr /* present_key */, nullptr /* present_value */, batch_size, sequence_length, sequence_length, parameters.head_size, parameters.v_head_size, parameters.v_hidden_size, - relative_position_bias, context); + attention_bias, context); } } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc index 6201b892a89b0..2c897f183164f 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc @@ -160,7 +160,7 @@ Status QAttention::Compute(OpKernelContext* context) const { bias->Shape(), mask_index, past_tensor, - nullptr, // relative_position_bias + nullptr, // attention_bias nullptr // parameters )); diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.h b/onnxruntime/contrib_ops/cuda/bert/packed_attention.h index 67b420764169a..cad28e7b70057 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.h @@ -46,7 +46,7 @@ class PackedAttention final : public TrtFusedAttention { const TensorShape& bias_shape, const TensorShape& packing_token_offset_shape, const TensorShape& cu_seq_len_shape, - const Tensor* relative_position_bias, + const Tensor* attention_bias, PackedAttentionParameters& parameters) const; int GetNumHeads() const { return num_heads_; } diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h index 9b52a70fc6181..3e59ce3dd229e 100644 ---
a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h @@ -23,7 +23,7 @@ class PackedMultiHeadAttention final : public TrtFusedAttention { const Tensor* bias, const TensorShape& token_offset_shape, const TensorShape& cu_seq_len_shape, - const Tensor* relative_position_bias, + const Tensor* attention_bias, PackedAttentionParameters& parameters) const; int GetNumHeads() const { return num_heads_; } float GetScale() const { return scale_; } diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index b62e566d43f89..3a5fc401c53af 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -52,7 +52,7 @@ Status QAttention::CheckInputs(const Tensor* input, auto& device_prop = GetDeviceProp(); ORT_RETURN_IF_ERROR(AttentionBase::CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask_index, past_tensor, - nullptr, // relative_position_bias + nullptr, // attention_bias parameters, device_prop.maxThreadsPerBlock)); diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc index 12835978536e1..3e93a527877c5 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc @@ -199,7 +199,7 @@ Status QOrderedAttention::ComputeInternal(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), merged_weights_shape, merged_bias_shape, mask_index, nullptr, // past - nullptr, // relative_position_bias + nullptr, // attention_bias nullptr, // parameters device_prop.maxThreadsPerBlock)); diff --git 
a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h index b4b501856a52e..62c1679743429 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_input_enum.h @@ -17,4 +17,4 @@ DefineQOrderedAttentionInput(Input, input, 0), DefineQOrderedAttentionInput(Scale_Values_Gemm, scale_values_gemm, 16), DefineQOrderedAttentionInput(Mask_Index, mask_index, 17), DefineQOrderedAttentionInput(Past, past, 18), - DefineQOrderedAttentionInput(relative_position_bias, relative_position_bias, 19) + DefineQOrderedAttentionInput(attention_bias, attention_bias, 19) diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 0745dcdf231e6..334090e8f305f 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -421,8 +421,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "T", OpSchema::Optional) .Input(5, - "relative_position_bias", - "additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)", + "attention_bias", + "additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)", "T", OpSchema::Optional) .Input(6, @@ -482,7 +482,7 @@ The operator only supports BERT like model with padding on right now. 
// Input 'bias': (hidden_size + hidden_size + v_hidden_size) // Input 'token_offset': (batch_size, sequence_length) // Input 'cumulative_sequence_length': (batch_size + 1) -// Input 'relative_position_bias': (batch_size, num_heads, sequence_length, sequence_length) +// Input 'attention_bias': (batch_size or 1, num_heads or 1, sequence_length, sequence_length) // Output 'output': (token_count, v_hidden_size) void PackedAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { // Type inference @@ -560,9 +560,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "A tensor with shape (batch_size + 1). It specifies the cumulative sequence length.", "M") .Input(5, - "relative_position_bias", - "A tensor with shape (batch_size, num_heads, sequence_length, sequence_length)" - "or (1, num_heads, sequence_length, sequence_length)." + "attention_bias", + "A tensor with shape (batch_size or 1, num_heads or 1, sequence_length, sequence_length)." "It specifies the additional bias to QxK'", "T", OpSchema::Optional) @@ -616,7 +615,7 @@ The operator only supports BERT like model with padding on right now. // Input 'bias': (hidden_size + hidden_size + v_hidden_size) // Input 'token_offset': (batch_size, sequence_length) // Input 'cumulative_sequence_length': (batch_size + 1) -// Input 'relative_position_bias': (batch_size or 1, num_heads, sequence_length, sequence_length) or None +// Input 'attention_bias': (batch_size or 1, num_heads or 1, sequence_length, sequence_length) or None // Output 'output': (token_count, v_hidden_size) void PackedMultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { // Type inference @@ -694,9 +693,9 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "A tensor with shape (batch_size + 1). It specifies the cumulative sequence length.", "M") .Input(6, - "relative_position_bias", - "It specifies the additional bias to QxK'. 
The shape is (batch_size, num_heads, sequence_length, sequence_length)" - " or (1, num_heads, sequence_length, sequence_length)", + "attention_bias", + "It specifies the additional bias to QxK'. " + "The shape is (batch_size or 1, num_heads or 1, sequence_length, sequence_length)", "T", OpSchema::Optional) .Output(0, @@ -778,8 +777,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "become (batch_size, num_heads, head_size / x, max_sequence_length, x) where `x = 16 / sizeof(T)`.", "T") .Input(5, - "relative_position_bias", - "additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)", + "attention_bias", + "additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)", "T", OpSchema::Optional) .Input(6, @@ -871,7 +870,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "M", OpSchema::Optional) .Input(4, - "relative_position_bias", + "attention_bias", "additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)", "T", OpSchema::Optional) diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index 762d892c45ce8..6f1f1c831d191 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -1146,7 +1146,7 @@ where value of each element is the end position, or valid length of actual seque left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by the inclusive start positions. When unidirectional is 1, and each token only attend to previous tokens. For GPT-2, both past and present state are optional. Present state could appear in output even when past state is not in input. -Current version does not support past/present, relative_position_bias and qkv_hidden_sizes. +Current version does not support past/present, attention_bias and qkv_hidden_sizes. 
TODO: Support them if needed in the future. )DOC"; @@ -1208,8 +1208,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Input(18, "past", "past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).", "Q", OpSchema::Optional) - .Input(19, "relative_position_bias", - "additional add to QxK' with shape (batch_size, num_heads, sequence_length, sequence_length).", "S", + .Input(19, "attention_bias", + "additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length).", "S", OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", "Q") .TypeConstraint("Q", {"tensor(int8)"}, "Constrain input and output types to int8 tensors.") diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index fd7b19dea724d..ce9780031a250 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -225,12 +225,12 @@ struct ProviderHostCPUImpl : ProviderHostCPU { const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* relative_position_bias, + const Tensor* attention_bias, void* parameters, const int max_threads_per_block, const Tensor* past_seq_len) override { return p->contrib::AttentionBase::CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, - relative_position_bias, + attention_bias, parameters, max_threads_per_block, past_seq_len); diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index 840d6f8e3e7aa..eb1569c3e499e 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -163,7 +163,7 @@ struct ProviderHostCPU { const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* relative_position_bias, + const 
Tensor* attention_bias, void* parameters, const int max_threads_per_block, const Tensor* past_seq_len) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 7fb9fd3c8cfd5..252ce9298bda8 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -608,12 +608,12 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, - const Tensor* relative_position_bias, + const Tensor* attention_bias, void* parameters, const int max_threads_per_block, const Tensor* past_seq_len) const { return g_host_cpu.AttentionBase__CheckInputs(this, input_shape, weights_shape, bias_shape, - mask_index, past, relative_position_bias, parameters, + mask_index, past, attention_bias, parameters, max_threads_per_block, past_seq_len); } Tensor* AttentionBase::GetPresent(OpKernelContext* context, const Tensor* past, int batch_size, int head_size, diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index 894e11275056e..5a26fedb5287d 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -1624,7 +1624,7 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: ModelP ] nis.extend([node.input[4] if len(node.input) > 4 else ""]) # 2D mask - nis.extend([node.input[5] if len(node.input) > 5 else ""]) # relative_position_bias + nis.extend([node.input[5] if len(node.input) > 5 else ""]) # attention_bias nis.extend([node.input[6] if len(node.input) > 6 else ""]) # past_key nis.extend([node.input[7] if len(node.input) > 7 else ""]) # past_value nis.extend(["past_sequence_length"]) # past_sequence_length diff --git 
a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py index 7384cace21a67..efdcbcfb3dcdc 100644 --- a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py @@ -68,7 +68,7 @@ def create_mha_node( v_matmul.output[0], "", # bias attn_mask, # key_padding_mask - add_qk, # relative_position_bias + add_qk, # attention_bias past_k, past_v, ] diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index a8e2fccdd0462..602966495f1cd 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -60,7 +60,7 @@ static void RunAttentionTest( const bool disable_rocm = false, const bool disable_dml = false, std::vector qkv_sizes = {}, - const std::vector& relative_position_bias_data = {}, + const std::vector& attention_bias_data = {}, int kv_sequence_length = 0, bool past_present_share_buffer = false, bool use_scale = false, @@ -205,12 +205,12 @@ static void RunAttentionTest( } } - std::vector relative_position_bias_data_dims = {batch_size, number_of_heads, sequence_length, sequence_length}; - if (relative_position_bias_data.size() > 0) { + std::vector attention_bias_data_dims = {batch_size, number_of_heads, sequence_length, sequence_length}; + if (attention_bias_data.size() > 0) { if (use_float16) { - tester.AddInput("relative_position_bias", relative_position_bias_data_dims, ToFloat16(relative_position_bias_data)); + tester.AddInput("attention_bias", attention_bias_data_dims, ToFloat16(attention_bias_data)); } else { - tester.AddInput("relative_position_bias", relative_position_bias_data_dims, relative_position_bias_data); + tester.AddInput("attention_bias", attention_bias_data_dims, attention_bias_data); } } else { if (use_float16) { @@ -292,7 +292,7 @@ static void RunAttentionTest( const bool 
disable_rocm = false, const bool disable_dml = false, const std::vector qkv_sizes = {}, - const std::vector& relative_position_bias_data = {}, + const std::vector& attention_bias_data = {}, int kv_sequence_length = 0, bool past_present_share_buffer = false, bool use_scale = false, @@ -301,13 +301,13 @@ static void RunAttentionTest( batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, mask_type, input_hidden_size, max_sequence_length, - disable_cpu, disable_cuda, disable_rocm, disable_dml, qkv_sizes, relative_position_bias_data, + disable_cpu, disable_cuda, disable_rocm, disable_dml, qkv_sizes, attention_bias_data, kv_sequence_length, past_present_share_buffer, use_scale, do_neox_rotary); RunAttentionTest(input_data, weights_data, true, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, mask_type, input_hidden_size, max_sequence_length, - disable_cpu, disable_cuda, disable_rocm, disable_dml, qkv_sizes, relative_position_bias_data, + disable_cpu, disable_cuda, disable_rocm, disable_dml, qkv_sizes, attention_bias_data, kv_sequence_length, past_present_share_buffer, use_scale, do_neox_rotary); } @@ -443,7 +443,7 @@ TEST(AttentionTest, AttentionBatch1RelativePositionBias) { std::vector mask_index_data = {2L}; - std::vector relative_position_bias = { + std::vector attention_bias = { 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f}; std::vector output_data = { @@ -457,7 +457,7 @@ TEST(AttentionTest, AttentionBatch1RelativePositionBias) { RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, false, false, false, 0, nullptr, nullptr, AttentionMaskType::MASK_1D_KEY_SEQ_LEN, 0, - 0, disable_cpu, disable_cuda, disable_rocm, disable_dml, qkv_sizes, 
relative_position_bias); + 0, disable_cpu, disable_cuda, disable_rocm, disable_dml, qkv_sizes, attention_bias); } TEST(AttentionTest, AttentionBatch2RelativePositionBias) { @@ -486,7 +486,7 @@ TEST(AttentionTest, AttentionBatch2RelativePositionBias) { std::vector mask_index_data = {2L, 2L}; - std::vector relative_position_bias = { + std::vector attention_bias = { 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f, 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f}; @@ -503,7 +503,7 @@ TEST(AttentionTest, AttentionBatch2RelativePositionBias) { RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, false, false, false, 0, nullptr, nullptr, AttentionMaskType::MASK_1D_KEY_SEQ_LEN, 0, - 0, disable_cpu, disable_cuda, disable_rocm, disable_dml, qkv_sizes, relative_position_bias); + 0, disable_cpu, disable_cuda, disable_rocm, disable_dml, qkv_sizes, attention_bias); } TEST(AttentionTest, AttentionBatch1_Float16) { @@ -1679,7 +1679,7 @@ TEST(AttentionTest, AttentionWithNormFactor) { use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, AttentionMaskType::MASK_2D_KEY_PADDING, 0 /*input_hidden_size*/, 0 /*max_sequence_length*/, false /*disable_cpu*/, false /*disable_cuda*/, true /*disable_rocm*/, false /*disable_dml*/, {} /*qkv_sizes*/, - {} /*relative_position_bias_data*/, 0 /*kv_sequence_length*/, false /*past_present_share_buffer*/, + {} /*attention_bias_data*/, 0 /*kv_sequence_length*/, false /*past_present_share_buffer*/, true /*use_scale*/); } @@ -1713,7 +1713,7 @@ TEST(AttentionTest, AttentionWithNeoXRotaryEmbedding) { use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, AttentionMaskType::MASK_2D_KEY_PADDING, 0 /*input_hidden_size*/, 0 /*max_sequence_length*/, true /*disable_cpu*/, false /*disable_cuda*/, true /*disable_rocm*/, disable_dml, {} /*qkv_sizes*/, - {} /*relative_position_bias_data*/, 0 
/*kv_sequence_length*/, false /*past_present_share_buffer*/, + {} /*attention_bias_data*/, 0 /*kv_sequence_length*/, false /*past_present_share_buffer*/, true /*use_scale*/, true /*use_neox_rotary_embedding*/); } diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index 65727828f51fb..3aaf710c33db4 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc @@ -31,7 +31,7 @@ static void RunMultiHeadAttentionTest( const std::vector& kv_data, // packed_kv: [batch_size, kv_sequence_length, num_heads, 2, head_size] const std::vector& qkv_data, // packed_qkv: [batch_size, sequence_length, num_heads, 3, head_size] const std::vector& bias_data, // bias: [hidden_size + hidden_size + v_hidden_size] or empty - const std::vector& attention_bias_data, // relative_position_bias: [1, num_heads, sequence_length, total_sequence_length] + const std::vector& attention_bias_data, // attention_bias: [1, num_heads, sequence_length, total_sequence_length] const std::vector& past_key_data, // past_key: [batch_size, num_heads, kv_sequence_length, head_size] const std::vector& past_value_data, // past_value: [batch_size, num_heads, kv_sequence_length, head_size] const std::vector& present_key_data, // present_key: [batch_size, num_heads, total_sequence_length, head_size] @@ -145,7 +145,7 @@ static void RunMultiHeadAttentionTest( } if (attention_bias_data.size()) { - tester.AddInput("relative_position_bias", attention_bias_dims, ToFloat16(attention_bias_data)); + tester.AddInput("attention_bias", attention_bias_dims, ToFloat16(attention_bias_data)); } else { tester.AddOptionalInputEdge(); } @@ -209,7 +209,7 @@ static void RunMultiHeadAttentionTest( } if (attention_bias_data.size()) { - tester.AddInput("relative_position_bias", attention_bias_dims, attention_bias_data); + tester.AddInput("attention_bias", attention_bias_dims, 
attention_bias_data); } else { tester.AddOptionalInputEdge(); } @@ -276,7 +276,7 @@ static void RunMultiHeadAttentionKernel( const std::vector& kv_data, // packed_kv: [batch_size, kv_sequence_length, num_heads, 2, head_size] const std::vector& qkv_data, // packed_qkv: [batch_size, sequence_length, num_heads, 3, head_size] const std::vector& bias_data, // bias: [hidden_size + hidden_size + v_hidden_size] - const std::vector& attention_bias_data, // relative_position_bias: [1, num_heads, sequence_length, total_sequence_length] + const std::vector& attention_bias_data, // attention_bias: [1, num_heads, sequence_length, total_sequence_length] const std::vector& past_key_data, // past_key: [batch_size, num_heads, kv_sequence_length, head_size] const std::vector& past_value_data, // past_value: [batch_size, num_heads, kv_sequence_length, head_size] const std::vector& present_key_data, // present_key: [batch_size, num_heads, total_sequence_length, head_size] diff --git a/onnxruntime/test/contrib_ops/packed_attention_op_test.cc b/onnxruntime/test/contrib_ops/packed_attention_op_test.cc index 09baf8def05f6..f87d464f0d952 100644 --- a/onnxruntime/test/contrib_ops/packed_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/packed_attention_op_test.cc @@ -30,7 +30,7 @@ static void RunPackedAttentionTest( bool use_float16, bool use_scale, std::vector qkv_sizes, - const std::vector& relative_position_bias_data) { + const std::vector& attention_bias_data) { int min_cuda_architecture = use_float16 ? 
530 : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); @@ -62,7 +62,7 @@ static void RunPackedAttentionTest( std::vector bias_dims = {qkv_hidden_size_sum}; std::vector token_offset_dims = {batch_size, sequence_length}; std::vector cum_seq_len_dims = {batch_size + 1}; - std::vector relative_position_bias_data_dims = {batch_size, number_of_heads, sequence_length, sequence_length}; + std::vector attention_bias_data_dims = {batch_size, number_of_heads, sequence_length, sequence_length}; std::vector output_dims = {token_count, v_hidden_size}; if (use_float16) { tester.AddInput("input", input_dims, ToFloat16(input_data)); @@ -70,8 +70,8 @@ static void RunPackedAttentionTest( tester.AddInput("bias", bias_dims, ToFloat16(bias_data)); tester.AddInput("token_offset", token_offset_dims, token_offset); tester.AddInput("cumulative_sequence_length", cum_seq_len_dims, cumulative_sequence_length); - if (relative_position_bias_data.size() > 0) { - tester.AddInput("relative_position_bias", relative_position_bias_data_dims, ToFloat16(relative_position_bias_data)); + if (attention_bias_data.size() > 0) { + tester.AddInput("attention_bias", attention_bias_data_dims, ToFloat16(attention_bias_data)); } tester.AddOutput("output", output_dims, ToFloat16(output_data)); @@ -81,8 +81,8 @@ static void RunPackedAttentionTest( tester.AddInput("bias", bias_dims, bias_data); tester.AddInput("token_offset", token_offset_dims, token_offset); tester.AddInput("cumulative_sequence_length", cum_seq_len_dims, cumulative_sequence_length); - if (relative_position_bias_data.size() > 0) { - tester.AddInput("relative_position_bias", relative_position_bias_data_dims, relative_position_bias_data); + if (attention_bias_data.size() > 0) { + tester.AddInput("attention_bias", attention_bias_data_dims, attention_bias_data); } tester.AddOutput("output", output_dims, output_data); @@ -107,7 +107,7 @@ static void RunPackedAttentionTest( int number_of_heads, int token_count, std::vector qkv_sizes = {}, - 
const std::vector& relative_position_bias_data = {}) { + const std::vector& attention_bias_data = {}) { #define InvokePackedAttentionTest(use_float16, use_scale) \ RunPackedAttentionTest( \ input_data, \ @@ -124,7 +124,7 @@ static void RunPackedAttentionTest( use_float16, \ use_scale, \ qkv_sizes, \ - relative_position_bias_data); + attention_bias_data); InvokePackedAttentionTest(true, true); InvokePackedAttentionTest(true, false); @@ -197,7 +197,7 @@ TEST(PackedAttentionTest, NoPackWithRelativePositionBias) { std::vector token_offset{0, 1, 2, 3}; std::vector cum_seq_len{0, 2, 4}; - std::vector relative_position_bias = { + std::vector attention_bias = { 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f, 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f}; @@ -220,7 +220,7 @@ TEST(PackedAttentionTest, NoPackWithRelativePositionBias) { number_of_heads, batch_size * sequence_length, {}, - relative_position_bias); + attention_bias); } TEST(PackedAttentionTest, PackedWithRelativePositionBias) { @@ -249,7 +249,7 @@ TEST(PackedAttentionTest, PackedWithRelativePositionBias) { std::vector token_offset{0, 1, 4, 5, 2, 3, 6, 7}; std::vector cum_seq_len{0, 2, 4}; - std::vector relative_position_bias = { + std::vector attention_bias = { 0.2f, -0.1f, 0.f, 0.f, 0.4f, 2.5f, 0.f, 0.f, 1.6f, -1.1f, 0.f, 0.f, 0.4f, -2.5f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, @@ -279,7 +279,7 @@ TEST(PackedAttentionTest, PackedWithRelativePositionBias) { number_of_heads, 4, {}, - relative_position_bias); + attention_bias); } TEST(PackedAttentionTest, PackedBatch) { diff --git a/onnxruntime/test/contrib_ops/qordered_attention_test.cc b/onnxruntime/test/contrib_ops/qordered_attention_test.cc index 1dd0162ad722f..b7cd3948b0e76 100644 --- a/onnxruntime/test/contrib_ops/qordered_attention_test.cc +++ b/onnxruntime/test/contrib_ops/qordered_attention_test.cc @@ -272,7 +272,7 @@ TEST(QOrderedTest, Attention_WithData_ROW_ORDER) { test_qorder.AddInput("scale_values_gemm", {}, {attn_out_scale}, 
true); test_qorder.AddInput("mask_index", {batch_size, sequence_len}, input_mask.data(), input_mask.size()); test_qorder.AddOptionalInputEdge(); // past - test_qorder.AddOptionalInputEdge(); // relative_position_bias + test_qorder.AddOptionalInputEdge(); // attention_bias test_qorder.AddOutput("output", {batch_size, sequence_len, hidden_size}, attn_out_q8.data(), attn_out_q8.size()); diff --git a/onnxruntime/test/python/transformers/test_parity_neox_attention.py b/onnxruntime/test/python/transformers/test_parity_neox_attention.py index d0a308987d888..300de19dd34c2 100644 --- a/onnxruntime/test/python/transformers/test_parity_neox_attention.py +++ b/onnxruntime/test/python/transformers/test_parity_neox_attention.py @@ -89,7 +89,7 @@ def create_neox_decoder_masked_self_attention_graph( "bias", "mask_index", "past", - "", # relative_position_bias + "", # attention_bias "past_sequence_length", ], ["output", "present"], diff --git a/onnxruntime/test/python/transformers/test_parity_t5_mha.py b/onnxruntime/test/python/transformers/test_parity_t5_mha.py index c7fb398dde82e..e4f65b07c552e 100644 --- a/onnxruntime/test/python/transformers/test_parity_t5_mha.py +++ b/onnxruntime/test/python/transformers/test_parity_t5_mha.py @@ -57,7 +57,7 @@ def create_t5_mha_graph( "value" if use_present or is_static_kv else "", "", # bias "key_padding_mask" if use_mask else "", - "relative_position_bias" if use_rpb else "", + "attention_bias" if use_rpb else "", "past_key" if use_past and not is_static_kv else "", "past_value" if use_past and not is_static_kv else "", ], @@ -93,9 +93,7 @@ def create_t5_mha_graph( if use_rpb: graph_inputs.append( - helper.make_tensor_value_info( - "relative_position_bias", TensorProto.FLOAT, [1, num_heads, seq_len, rpb_length] - ) + helper.make_tensor_value_info("attention_bias", TensorProto.FLOAT, [1, num_heads, seq_len, rpb_length]) ) if use_past and not is_static_kv: @@ -170,7 +168,7 @@ def create_t5_decoder_masked_mha_graph( "key", "value", "mask_index" 
if is_cross_attention else "", - "relative_position_bias" if not is_cross_attention else "", + "attention_bias" if not is_cross_attention else "", "past_key" if not is_cross_attention else "", "past_value" if not is_cross_attention else "", "past_sequence_length" if not is_cross_attention else "", @@ -220,7 +218,7 @@ def create_t5_decoder_masked_mha_graph( graph_inputs.append(helper.make_tensor_value_info("value", TensorProto.FLOAT, [batch_size, 1, hidden_size])) graph_inputs.append( helper.make_tensor_value_info( - "relative_position_bias", TensorProto.FLOAT, [1, num_heads, 1, past_sequence_length + 1] + "attention_bias", TensorProto.FLOAT, [1, num_heads, 1, past_sequence_length + 1] ) ) # use past_sequence_length + 1 to simulate max_sequence_length @@ -558,7 +556,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): if torch_key_padding_mask is not None: ort_inputs["key_padding_mask"] = np.ascontiguousarray(torch_key_padding_mask.detach().numpy()) if torch_position_bias is not None: - ort_inputs["relative_position_bias"] = np.ascontiguousarray(torch_position_bias.detach().numpy()) + ort_inputs["attention_bias"] = np.ascontiguousarray(torch_position_bias.detach().numpy()) else: torch_past_key = past_key_value[0] torch_past_value = past_key_value[1] @@ -617,7 +615,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): else: ort_inputs["key_padding_mask"] = np.ascontiguousarray(torch_key_padding_mask.detach().numpy()) if torch_position_bias is not None: - ort_inputs["relative_position_bias"] = np.ascontiguousarray(torch_position_bias.detach().numpy()) + ort_inputs["attention_bias"] = np.ascontiguousarray(torch_position_bias.detach().numpy()) ort_output = ort_session.run(None, ort_inputs) From 2bff1881ff3f931d042e89252681ea3ba56a2d11 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 14 Aug 2024 16:47:39 -0700 Subject: [PATCH 05/13] fix build --- onnxruntime/contrib_ops/cpu/utils/console_dumper.h | 2 +- 
.../cuda/bert/decoder_masked_multihead_attention.cc | 2 ++ .../decoder_masked_multihead_attention_impl.cu | 6 +++--- .../decoder_masked_multihead_attention_impl.h | 2 ++ .../src/External/DirectMLHelpers/GeneratedSchemaHelpers.h | 4 ++-- .../src/Operators/DmlOperatorAttention.cpp | 2 +- .../src/Operators/DmlOperatorMultiHeadAttention.cpp | 2 +- .../src/Operators/DmlOperatorQAttention.cpp | 2 +- onnxruntime/test/contrib_ops/attention_op_test.cc | 4 ++-- onnxruntime/test/contrib_ops/packed_attention_op_test.cc | 4 ++-- 10 files changed, 17 insertions(+), 13 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/utils/console_dumper.h b/onnxruntime/contrib_ops/cpu/utils/console_dumper.h index ff7921fc70da3..12cbc5049a02a 100644 --- a/onnxruntime/contrib_ops/cpu/utils/console_dumper.h +++ b/onnxruntime/contrib_ops/cpu/utils/console_dumper.h @@ -53,7 +53,7 @@ void PrintTensorByDims(const TConsoleDumper* dumper, const char* name, const T* tensor, gsl::span& dims) { - if (dumper->IsEnabled && (tensor == nullptr || dims.size() == 0)) { + if (dumper->IsEnabled() && (tensor == nullptr || dims.size() == 0)) { std::cout << std::string(name) << " is None" << std::endl; return; } diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc index 350c4718c437e..3070d98a77cb4 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc @@ -144,6 +144,8 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* // Update the attention bias for self attention if (attention_bias != nullptr) { parameters.attention_bias = const_cast(attention_bias->Data()); + parameters.broadcast_attention_bias_dim_0 = parameters.attention_bias_dims[0] == 1; + parameters.broadcast_attention_bias_dim_1 = parameters.attention_bias_dims[1] == 1; } // Decoder cross-attention diff --git 
a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index a0115ef9f2304..2bbaf50ec55da 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -156,12 +156,12 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio // The offset of attention bias for current head. int64_t attn_bias_offset = 0; - if (params.attention_bias != nullptr && params.attention_bias_dims.size() == 4) { + if (params.attention_bias != nullptr) { // Support broadcasting the first and second dimensions of attention bias. - if (params.attention_bias_dims[0] > 1) { + if (!params.broadcast_attention_bias_dim_0) { attn_bias_offset = static_cast(bbi) * params.num_heads * params.sequence_length * params.total_sequence_length; } - if (params.attention_bias_dims[1] > 1) { + if (!params.broadcast_attention_bias_dim_1) { attn_bias_offset += static_cast(hi) * params.sequence_length * params.total_sequence_length; } } diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h index efad33855328f..2cb276f82bf18 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h @@ -38,6 +38,8 @@ struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters { void* v_bias = nullptr; void* attention_bias = nullptr; + bool broadcast_attention_bias_dim_0 = false; + 
bool broadcast_attention_bias_dim_1 = false; void* k_cache = nullptr; void* v_cache = nullptr; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h index 04ad595b241b0..23b5a491c7d96 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h @@ -1471,7 +1471,7 @@ inline std::vector GetFields(const DML_MULTIHEAD_ATTENTION_OPERAT OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.StackedQueryKeyValueTensor))), OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.BiasTensor))), OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.MaskTensor))), - OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.AttentionBiasTensor))), + OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.RelativePositionBiasTensor))), OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.PastKeyTensor))), OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.PastValueTensor))), OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.OutputTensor))), @@ -1566,7 +1566,7 @@ inline std::vector GetFields(const DML_MULTIHEAD_ATTENTION1_OPERA OperatorField(&DML_MULTIHEAD_ATTENTION1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.StackedQueryKeyValueTensor))), OperatorField(&DML_MULTIHEAD_ATTENTION1_OPERATOR_SCHEMA.Fields[6], 
ToOperatorFieldType(static_cast(desc.BiasTensor))), OperatorField(&DML_MULTIHEAD_ATTENTION1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.MaskTensor))), - OperatorField(&DML_MULTIHEAD_ATTENTION1_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.AttentionBiasTensor))), + OperatorField(&DML_MULTIHEAD_ATTENTION1_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.RelativePositionBiasTensor))), OperatorField(&DML_MULTIHEAD_ATTENTION1_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.PastKeyTensor))), OperatorField(&DML_MULTIHEAD_ATTENTION1_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.PastValueTensor))), OperatorField(&DML_MULTIHEAD_ATTENTION1_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.PastSequenceLengthsTensor))), diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp index 5409d1c653d47..f913fb2e02a75 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp @@ -346,7 +346,7 @@ class DmlOperatorAttention : public DmlOperator mhaOperatorDesc.MaskTensor = hasMask ? &inputDescs[dmlMaskIndex] : nullptr; } - mhaOperatorDesc.AttentionBiasTensor = hasAttentionBias ? &inputDescs[dmlAttentionBiasIndex] : nullptr; + mhaOperatorDesc.RelativePositionBiasTensor = hasAttentionBias ? 
&inputDescs[dmlAttentionBiasIndex] : nullptr; mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex]; mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute(AttrName::Scale, gsl::narrow_cast(1.0f / std::sqrt(headSize))); mhaOperatorDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute(AttrName::MaskFilterValue, -10'000.0f); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp index 96d2408f118a6..1e747e34c9acb 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp @@ -283,7 +283,7 @@ class DmlOperatorMultiHeadAttention : public DmlOperator mhaDesc.StackedQueryKeyValueTensor = stackedQkv ? &inputDescs[dmlStackedQueryKeyValueIndex] : nullptr; mhaDesc.BiasTensor = hasBias ? &inputDescs[dmlBiasIndex] : nullptr; mhaDesc.MaskTensor = hasMask ? &inputDescs[dmlMaskIndex] : nullptr; - mhaDesc.AttentionBiasTensor = hasAttentionBias ? &inputDescs[dmlAttentionBiasIndex] : nullptr; + mhaDesc.RelativePositionBiasTensor = hasAttentionBias ? &inputDescs[dmlAttentionBiasIndex] : nullptr; mhaDesc.PastKeyTensor = hasPastKey ? &inputDescs[dmlPastKeyIndex] : nullptr; mhaDesc.PastValueTensor = hasPastValue ? 
&inputDescs[dmlPastValueIndex] : nullptr; mhaDesc.OutputTensor = &outputDescs[outputIndex]; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp index f45cb6c90b352..d6fd83fd583de 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp @@ -415,7 +415,7 @@ class DmlOperatorQAttention : public DmlOperator mhaOperatorDesc.MaskTensor = hasMask ? &inputDescs[maskIndex] : nullptr; } - mhaOperatorDesc.AttentionBiasTensor = nullptr; + mhaOperatorDesc.RelativePositionBiasTensor = nullptr; mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex]; mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute(AttrName::Scale, gsl::narrow_cast(1.0f / std::sqrt(headSize))); // Set MaskFilterValue to lowest float for Causal Mask diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index 602966495f1cd..61e5fa05c66c1 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -419,7 +419,7 @@ TEST(AttentionTest, AttentionBatch1WithQKVAttr2) { 0, false, false, disable_rocm, false, qkv_sizes); } -TEST(AttentionTest, AttentionBatch1RelativePositionBias) { +TEST(AttentionTest, AttentionBatch1AttentionBias) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -460,7 +460,7 @@ TEST(AttentionTest, AttentionBatch1RelativePositionBias) { 0, disable_cpu, disable_cuda, disable_rocm, disable_dml, qkv_sizes, attention_bias); } -TEST(AttentionTest, AttentionBatch2RelativePositionBias) { +TEST(AttentionTest, AttentionBatch2AttentionBias) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; diff --git 
a/onnxruntime/test/contrib_ops/packed_attention_op_test.cc b/onnxruntime/test/contrib_ops/packed_attention_op_test.cc index f87d464f0d952..96c629b4616d5 100644 --- a/onnxruntime/test/contrib_ops/packed_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/packed_attention_op_test.cc @@ -172,7 +172,7 @@ TEST(PackedAttentionTest, NoPack) { batch_size * sequence_length); } -TEST(PackedAttentionTest, NoPackWithRelativePositionBias) { +TEST(PackedAttentionTest, NoPackWithAttentionBias) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -223,7 +223,7 @@ TEST(PackedAttentionTest, NoPackWithRelativePositionBias) { attention_bias); } -TEST(PackedAttentionTest, PackedWithRelativePositionBias) { +TEST(PackedAttentionTest, PackedWithAttentionBias) { int batch_size = 2; int sequence_length = 4; int hidden_size = 4; From 58792dd50f5d580d0df42f1521623d39bd117e77 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 15 Aug 2024 03:35:32 +0000 Subject: [PATCH 06/13] update doc --- docs/ContribOperators.md | 28 ++++++++++++++-------------- docs/OperatorKernels.md | 22 +++++++++++----------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index c60b25f3418f6..0048190f9063b 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -180,8 +180,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length), or index with shape (batch_size) or (2 * batch_size) or (3 * batch_size + 2)
past (optional) : T
past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size). When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size)
-
relative_position_bias (optional) : T
-
additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)
+
attention_bias (optional) : T
+
additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)
past_sequence_length (optional) : M
When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).
@@ -1166,7 +1166,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Value with shape (batch_size, 1, v_hidden_size) for self attention or past_value with shape (batch_size, num_heads, kv_sequence_length, head_size) for cross attention
mask_index (optional) : M
Mask values of shape (batch_size, total_sequence_length) or (batch_size, kv_sequence_length)
-
relative_position_bias (optional) : T
+
attention_bias (optional) : T
additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)
past_key (optional) : T
past state for key with shape (batch_size, num_heads, past_sequence_length, head_size) for self attention. When past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size). The keys buffer is re-ordered in such a way that its virtual sub-tensor of shape (batch_size, num_heads, max_sequence_length, head_size) which may be perceived as being of shape (batch_size, num_heads, max_sequence_length, head_size / x, x) is reordered to become (batch_size, num_heads, head_size / x, max_sequence_length, x) where `x = 16 / sizeof(T)`.
@@ -1256,8 +1256,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Mask values of shape (batch_size, total_sequence_length)
past : T
past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size). When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size). The first `batch_size * num_heads * max_sequence_length * head_size` elements correspond to keys and the next `batch_size * num_heads * max_sequence_length * head_size` elements correspond to values. The keys buffer is re-ordered in such a way that its virtual sub-tensor of shape (batch_size, num_heads, max_sequence_length, head_size) which may be perceived as being of shape (batch_size, num_heads, max_sequence_length, head_size / x, x) is reordered to become (batch_size, num_heads, head_size / x, max_sequence_length, x) where `x = 16 / sizeof(T)`.
-
relative_position_bias (optional) : T
-
additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)
+
attention_bias (optional) : T
+
additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)
past_sequence_length : M
When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).
beam_width (optional) : M
@@ -3202,8 +3202,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection
key_padding_mask (optional) : M
Key padding mask with shape (batch_size), (3 * batch_size + 2), (batch_size, kv_sequence_length), (batch_size, total_sequence_length), or (batch_size, sequence_length, total_sequence_length)
-
relative_position_bias (optional) : T
-
relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)
+
attention_bias (optional) : T
+
bias added to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)
past_key (optional) : T
past state for self attention key with shape (batch_size, num_heads, past_sequence_length, head_size)
past_value (optional) : T
@@ -3516,8 +3516,8 @@ This version of the operator has been available since version 1 of the 'com.micr
In packing mode, it specifies the offset of each token, with shape (batch_size, sequence_length).
cumulative_sequence_length : M
A tensor with shape (batch_size + 1). It specifies the cumulative sequence length.
-
relative_position_bias (optional) : T
-
A tensor with shape (batch_size, num_heads, sequence_length, sequence_length)or (1, num_heads, sequence_length, sequence_length).It specifies the additional bias to QxK'
+
attention_bias (optional) : T
+
A tensor with shape (batch_size or 1, num_heads or 1, sequence_length, sequence_length). It specifies the additional bias to QxK'
#### Outputs @@ -3591,8 +3591,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Offset of each token before packing, with shape (batch_size, sequence_length).
cumulative_sequence_length : M
A tensor with shape (batch_size + 1). It specifies the cumulative sequence length.
-
relative_position_bias (optional) : T
-
It specifies the additional bias to QxK'. The shape is (batch_size, num_heads, sequence_length, sequence_length) or (1, num_heads, sequence_length, sequence_length)
+
attention_bias (optional) : T
+
It specifies the additional bias to QxK'. The shape is (batch_size or 1, num_heads or 1, sequence_length, sequence_length)
#### Outputs @@ -4468,7 +4468,7 @@ This version of the operator has been available since version 1 of the 'com.micr left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by the inclusive start positions. When unidirectional is 1, and each token only attend to previous tokens. For GPT-2, both past and present state are optional. Present state could appear in output even when past state is not in input. - Current version does not support past/present, relative_position_bias and qkv_hidden_sizes. + Current version does not support past/present, attention_bias and qkv_hidden_sizes. TODO: Support them if needed in the future. #### Version @@ -4533,8 +4533,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, past_sequence_length + sequence_length) or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).
past (optional) : Q
past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).
-
relative_position_bias (optional) : S
-
additional add to QxK' with shape (batch_size, num_heads, sequence_length, sequence_length).
+
attention_bias (optional) : S
+
additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length).
#### Outputs diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index f0aa332ff39eb..96173b5a4ea4a 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -460,7 +460,7 @@ Do not modify directly.* | | | | |**Operator Domain:** *com.microsoft*|||| -|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)| +|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* attention_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)| |AttnLSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* QW:**T**
*in* MW:**T**
*in* V:**T**
*in* M:**T**
*in* memory_seq_lens:**T1**
*in* AW:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)
**T1** = tensor(int32)| |BeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**|1+|**T** = tensor(float)| |BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float)| @@ -490,7 +490,7 @@ Do not modify directly.* |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(float), tensor(uint8)
**T4** = tensor(int32)| |MaxpoolWithMask|*in* X:**T**
*in* M:**tensor(int32)**
*out* Y:**T**|1+|**T** = tensor(float)| -|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float)| +|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float)| |MurmurHash3|*in* X:**T1**
*out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(uint32)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| |NhwcMaxPool|*in* x:**T**
*out* y:**T**|1+|**T** = tensor(int8), tensor(uint8)| @@ -848,7 +848,7 @@ Do not modify directly.* | | | | |**Operator Domain:** *com.microsoft*|||| -|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| +|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* attention_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| |BeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**|1+|**T** = tensor(float), tensor(float16)| |BiasAdd|*in* X:**T**
*in* bias:**T**
*in* skip:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |BiasDropout|*in* data:**T**
*in* bias:**T**
*in* residual:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| @@ -861,8 +861,8 @@ Do not modify directly.* |ComplexMulConj|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)| |ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |DecoderAttention|*in* query:**T**
*in* key:**T**
*in* q_weight:**T**
*in* kv_weight:**T**
*in* bias:**T**
*in* key_padding_mask:**B**
*in* key_cache:**T**
*in* value_cache:**T**
*in* static_kv:**B**
*in* use_past:**B**
*in* has_layer_state:**B**
*in* has_key_padding_mask:**B**
*out* output:**T**
*out* new_key_cache:**T**
*out* new_value_cache:**T**|1+|**T** = tensor(float), tensor(float16)| -|DecoderMaskedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* mask_index:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*in* bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**V**|1+|**T** = tensor(float), tensor(float16)| -|DecoderMaskedSelfAttention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| +|DecoderMaskedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* mask_index:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*in* bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**V**|1+|**T** = tensor(float), tensor(float16)| +|DecoderMaskedSelfAttention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* attention_bias:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| |DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(float16)| |DequantizeWithOrder|*in* input:**Q**
*in* scale_input:**S**
*out* output:**F**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)| |DynamicTimeWarping|*in* input:**F**
*out* output:**I**|1+|**F** = tensor(float)
**I** = tensor(int32)| @@ -884,14 +884,14 @@ Do not modify directly.* |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| -|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| +|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| -|PackedAttention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* relative_position_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| -|PackedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* relative_position_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|PackedAttention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* attention_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|PackedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* attention_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)| |QMoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T1**
*in* fc1_scales:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T1**
*in* fc2_scales:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T1**
*in* fc3_scales:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float16)
**T1** = tensor(uint8)| -|QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* relative_position_bias:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| +|QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* attention_bias:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedGelu|*in* X:**Q**
*in* scale_X:**S**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedLayerNormalization|*in* X:**Q**
*in* scale_X:**S**
*in* scale:**F**
*in* B:**F**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedLongformerAttention|*in* input:**Q**
*in* scale_input:**S**
*in* weight:**Q**
*in* scale_weight:**S**
*in* bias:**S**
*in* scale_bias:**S**
*in* scale_qkv_gemm:**S**
*in* mask:**F**
*in* global_weight:**Q**
*in* scale_global_weight:**S**
*in* global_bias:**S**
*in* scale_global_gemm:**S**
*in* global:**G**
*in* scale_output:**S**
*out* output:**Q**|1+|**F** = tensor(float16)
**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| @@ -1296,7 +1296,7 @@ Do not modify directly.* | | | | |**Operator Domain:** *com.microsoft*|||| -|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* attention_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |BiasAdd|*in* X:**T**
*in* bias:**T**
*in* skip:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)| |BiasSplitGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -1312,7 +1312,7 @@ Do not modify directly.* |GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| -|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)| |QLinearAdd|*in* A:**T**
*in* A_scale:**tensor(float)**
*in* A_zero_point:**T**
*in* B:**T**
*in* B_scale:**tensor(float)**
*in* B_zero_point:**T**
*in* C_scale:**tensor(float)**
*in* C_zero_point:**T**
*out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)| From acfd6117284e9cf0d7fa7a8d56454a49e197f80b Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 15 Aug 2024 04:07:07 +0000 Subject: [PATCH 07/13] format js --- js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 10 +++--- .../jsep/webgpu/ops/multihead-attention.ts | 32 +++++++------------ 2 files changed, 18 insertions(+), 24 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index 776944f644a50..8840ef97b4279 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -223,10 +223,12 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte } // TODO: support broadcasting the first and second dimensions of attention_bias - if (attentionBias.dims[0] !== batchSize || - attentionBias.dims[1] !== attributes.numHeads || - attentionBias.dims[2] !== sequenceLength || - attentionBias.dims[3] !== totalSequenceLength) { + if ( + attentionBias.dims[0] !== batchSize || + attentionBias.dims[1] !== attributes.numHeads || + attentionBias.dims[2] !== sequenceLength || + attentionBias.dims[3] !== totalSequenceLength + ) { throw new Error('Expect "attention_bias" shape (batch_size, num_heads, sequence_length, total_sequence_length)'); } } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts index 64f2103713f40..72e09303ba76f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts @@ -79,7 +79,7 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr const batchSize = query.dims[0]; const sequenceLength = query.dims[1]; - const hiddenSize = query.dims.length === 3 ? query.dims[2] : (attributes.numHeads * query.dims[4]); + const hiddenSize = query.dims.length === 3 ? 
query.dims[2] : attributes.numHeads * query.dims[4]; let kvSequenceLength = sequenceLength; let pastSequenceLength = 0; @@ -147,7 +147,8 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr qkvFormat = AttentionQkvFormat.unknown; // Q_K_V_BSNH_BNSH_BNSH kvSequenceLength = key.dims[2]; } - } else { // packed QKV + } else { + // packed QKV if (query.dims.length !== 5) { throw new Error('Input "query" is expected to have 5 dimensions when key is empty'); } @@ -207,7 +208,8 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr throw new Error('Input "key" and "value" shall have the same dim 1 (kv_sequence_length)'); } vHiddenSize = value.dims[2]; - } else { // Q_K_V_BSNH_BNSH_BNSH + } else { + // Q_K_V_BSNH_BNSH_BNSH if (kvSequenceLength !== value.dims[2]) { throw new Error('Input "key" and "value" shall have the same dim 2 (kv_sequence_length)'); } @@ -228,10 +230,12 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr } // TODO: support broadcasting the first and second dimensions of attention_bias. 
- if (attentionBias.dims[0] !== batchSize || - attentionBias.dims[1] !== attributes.numHeads || - attentionBias.dims[2] !== sequenceLength || - attentionBias.dims[3] !== totalSequenceLength) { + if ( + attentionBias.dims[0] !== batchSize || + attentionBias.dims[1] !== attributes.numHeads || + attentionBias.dims[2] !== sequenceLength || + attentionBias.dims[3] !== totalSequenceLength + ) { throw new Error('Expect "attention_bias" shape (batch_size, num_heads, sequence_length, total_sequence_length)'); } } @@ -432,17 +436,5 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio 2 * params.hiddenSize, ); - applyAttention( - context, - Q, - K, - V, - keyPaddingMask, - undefined, - pastKey, - pastValue, - attentionBias, - params, - attributes, - ); + applyAttention(context, Q, K, V, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params, attributes); }; From 1eb8c6b3922a859c9f43ac4d0ba3661162405a1b Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 15 Aug 2024 08:02:23 +0000 Subject: [PATCH 08/13] refactoring --- onnxruntime/contrib_ops/cpu/bert/attention_base.cc | 3 ++- onnxruntime/contrib_ops/cpu/bert/attention_common.h | 6 ++++-- .../cpu/bert/multihead_attention_helper.h | 3 ++- onnxruntime/contrib_ops/cuda/bert/attention_impl.cu | 9 ++++++--- onnxruntime/contrib_ops/cuda/bert/attention_impl.h | 1 - .../contrib_ops/cuda/bert/attention_prepare_qkv.cu | 12 +++++++++--- .../cuda/bert/cutlass_fmha/fmha_launch_template.h | 12 ++++-------- .../bert/cutlass_fmha/memory_efficient_attention.h | 3 ++- .../cuda/bert/decoder_masked_multihead_attention.cc | 2 -- .../decoder_masked_multihead_attention_impl.cu | 4 ++-- .../decoder_masked_multihead_attention_impl.h | 2 -- .../contrib_ops/cuda/bert/multihead_attention.cc | 1 - .../contrib_ops/cuda/bert/packed_attention.cc | 8 +++++--- onnxruntime/contrib_ops/cuda/bert/packed_attention.h | 4 +++- .../contrib_ops/cuda/bert/packed_attention_impl.cu | 9 ++++++--- 
.../cuda/bert/packed_multihead_attention.cc | 7 +++++-- .../cuda/bert/packed_multihead_attention_impl.cu | 11 ++++++----- .../test/python/transformers/benchmark_mha.py | 7 +------ 18 files changed, 57 insertions(+), 47 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc index 21e4b4c7932bc..4573913776e2c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc @@ -231,7 +231,8 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, output_parameters->mask_filter_value = mask_filter_value_; output_parameters->scale = scale_; output_parameters->mask_type = mask_type; - output_parameters->attention_bias_dims = attention_bias_dims; + output_parameters->broadcast_attn_bias_dim_0 = attention_bias_dims.size() > 0 && attention_bias_dims[0] == 1; + output_parameters->broadcast_attn_bias_dim_1 = attention_bias_dims.size() > 1 && attention_bias_dims[1] == 1; output_parameters->qkv_format = Q_K_V_BNSH; } diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 6ea293ea3a870..5a5899166f5ba 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -69,7 +69,8 @@ struct AttentionParameters { bool is_unidirectional; bool past_present_share_buffer; bool do_rotary; - gsl::span attention_bias_dims; + bool broadcast_attn_bias_dim_0; + bool broadcast_attn_bias_dim_1; float mask_filter_value; float scale; bool use_tf32; @@ -89,7 +90,8 @@ struct PackedAttentionParameters { int num_heads; float scale; int token_count; - gsl::span attention_bias_dims; + bool broadcast_attn_bias_dim_0; + bool broadcast_attn_bias_dim_1; bool use_tf32; }; diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index 
8a644147dbcca..0cfe90963c334 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -427,7 +427,8 @@ Status CheckInputs(const T* query, output_parameters->mask_filter_value = mask_filter_value; output_parameters->mask_type = mask_type; output_parameters->scale = scale; - output_parameters->attention_bias_dims = attention_bias_dims; + output_parameters->broadcast_attn_bias_dim_0 = attention_bias_dims.size() > 0 && attention_bias_dims[0] == 1; + output_parameters->broadcast_attn_bias_dim_1 = attention_bias_dims.size() > 1 && attention_bias_dims[1] == 1; output_parameters->qkv_format = qkv_format; } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 5508388e99257..107f08b4c89cc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -361,8 +361,11 @@ Status EfficientAttention( p.query = data.q; p.key = data.k; p.value = data.v; + p.attn_bias = (nullptr == data.attention_bias) ? 
nullptr : data.attention_bias; - p.attn_bias_dims = data.attention_bias_dims; + p.broadcast_attn_bias_dim_0 = parameters.broadcast_attn_bias_dim_0; + p.broadcast_attn_bias_dim_1 = parameters.broadcast_attn_bias_dim_1; + p.output = data.output; p.is_kv_bsnh = data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH; p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float)) @@ -438,8 +441,8 @@ Status UnfusedAttention( sequence_length, total_sequence_length); T* scratch2 = data.scratch + (bytes / element_size); - bool broadcast_attn_bias_dim_0 = parameters.attention_bias_dims.size() > 0 && parameters.attention_bias_dims[0] == 1; - bool broadcast_attn_bias_dim_1 = parameters.attention_bias_dims.size() > 1 && parameters.attention_bias_dims[1] == 1; + const bool broadcast_attn_bias_dim_0 = parameters.broadcast_attn_bias_dim_0; + const bool broadcast_attn_bias_dim_1 = parameters.broadcast_attn_bias_dim_1; // Apply softmax and store result R to scratch2: BxNxSxT if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index 29b6c1f53a7e3..a6760f84e69f3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -70,7 +70,6 @@ struct AttentionData { const T* past_key = nullptr; const T* past_value = nullptr; const T* attention_bias = nullptr; - gsl::span attention_bias_dims; bool has_qkv_workspace = false; T* workspace = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index d34e6a92bab03..575e65ebef0e9 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -78,15 +78,21 @@ void DumpInputs(contrib::AttentionParameters& parameters, AttentionData& data } if 
(data.attention_bias != nullptr) { - DUMP_TENSOR_D("attention_bias", data.attention_bias, parameters.attention_bias_dims); + DUMP_TENSOR_D("attention_bias", data.attention_bias, + parameters.broadcast_attn_bias_dim_0 ? 1 : batch_size, + parameters.broadcast_attn_bias_dim_1 ? 1 : num_heads, + sequence_length, + parameters.total_sequence_length); } if (data.mask_index != nullptr) { if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { - DUMP_TENSOR_D("mask", data.mask_index, batch_size, parameters.total_sequence_length); + DUMP_TENSOR_D("mask (2D)", data.mask_index, batch_size, parameters.total_sequence_length); } if (parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) { - DUMP_TENSOR_D("mask", data.mask_index, 3 * batch_size + 2, 1); + DUMP_TENSOR_D("mask (seqlen_k)", data.mask_index, 1, batch_size); + DUMP_TENSOR_D("mask (cu_seqlen_q)", data.mask_index + batch_size, 1, batch_size + 1); + DUMP_TENSOR_D("mask (cu_seqlen_k)", data.mask_index + 2 * batch_size + 1, 1, batch_size + 1); } } } diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index 39d70bf1ea9bc..1598a7e8bcf1e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -210,15 +210,11 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { } if (params.attn_bias != nullptr) { - auto& bias_dims = params.attn_bias_dims; - ORT_ENFORCE(bias_dims.size() == 4 && - (bias_dims[0] == 1 || bias_dims[0] == params.batch_size) && - (bias_dims[1] == 1 || bias_dims[1] == params.num_heads) && - bias_dims[2] == params.sequence_length && - bias_dims[3] == params.kv_sequence_length); - p.bias_strideH = (bias_dims[1] == 1) ? 0 : p.num_queries * p.num_keys; + p.bias_strideH = params.broadcast_attn_bias_dim_1 ? 
0 : p.num_queries * p.num_keys; p.bias_strideM = p.num_keys; - p.bias_strideB = (bias_dims[0] == 1) ? 0 : (bias_dims[1] * p.num_queries * p.num_keys); + p.bias_strideB = params.broadcast_attn_bias_dim_0 + ? 0 + : ((params.broadcast_attn_bias_dim_1 ? 1 : params.num_heads) * p.num_queries * p.num_keys); } else { p.bias_strideH = 0; p.bias_strideM = 0; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h index 918eec15f45b1..a9777800f6038 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h @@ -36,7 +36,8 @@ struct MemoryEfficientAttentionParams { const void* key; // [B, L, N, H], where L is kv_sequence_length const void* value; // [B, L, N, H_v] const void* attn_bias; // [B or 1, N or 1, S, L] or null - gsl::span attn_bias_dims; + bool broadcast_attn_bias_dim_0; + bool broadcast_attn_bias_dim_1; void* output; // [B, S, N, H_v] void* workspace; // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc index 3070d98a77cb4..350c4718c437e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc @@ -144,8 +144,6 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* // Update the attention bias for self attention if (attention_bias != nullptr) { parameters.attention_bias = const_cast(attention_bias->Data()); - parameters.broadcast_attention_bias_dim_0 = parameters.attention_bias_dims[0] == 1; - parameters.broadcast_attention_bias_dim_1 = parameters.attention_bias_dims[1] == 1; } // Decoder cross-attention diff --git 
a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index 2bbaf50ec55da..235b37368ea6b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -158,10 +158,10 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio int64_t attn_bias_offset = 0; if (params.attention_bias != nullptr) { // Support broadcasting the first and second dimensions of attention bias. - if (!params.broadcast_attention_bias_dim_0) { + if (!params.broadcast_attn_bias_dim_0) { attn_bias_offset = static_cast(bbi) * params.num_heads * params.sequence_length * params.total_sequence_length; } - if (!params.broadcast_attention_bias_dim_1) { + if (!params.broadcast_attn_bias_dim_1) { attn_bias_offset += static_cast(hi) * params.sequence_length * params.total_sequence_length; } } diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h index 2cb276f82bf18..efad33855328f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h @@ -38,8 +38,6 @@ struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters { void* v_bias = nullptr; void* attention_bias = nullptr; - bool broadcast_attention_bias_dim_0 = false; - bool broadcast_attention_bias_dim_1 = false; void* k_cache = nullptr; void* v_cache = nullptr; diff --git 
a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index da20521fb42d1..b2fd9b5e89de1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -270,7 +270,6 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast(past_value->Data()); if (nullptr != attention_bias) { data.attention_bias = reinterpret_cast(attention_bias->Data()); - data.attention_bias_dims = attention_bias->Shape().GetDims(); } data.output = reinterpret_cast(output->MutableData()); data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast(present_key->MutableData()); diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc index 2a2df723e4f58..0e5300f32da3c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc @@ -43,11 +43,12 @@ TrtFusedAttention::TrtFusedAttention(const OpKernelInfo& info) template MHARunner* TrtFusedAttention::GetFusedRunner(const cudaDeviceProp& device_prop, + bool has_attention_bias, const PackedAttentionParameters& parameters) const { MHARunner* fused_runner = nullptr; bool use_fused_runner = !disable_fused_runner_ && - parameters.attention_bias_dims.empty() && + !has_attention_bias && parameters.hidden_size == parameters.v_hidden_size; if (!use_fused_runner) { @@ -211,7 +212,8 @@ Status PackedAttention::CheckInputs(const TensorShape& input_shape, ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckAttentionBias( attention_bias_dims, batch_size, num_heads, sequence_length, sequence_length)); } - parameters.attention_bias_dims = attention_bias_dims; + parameters.broadcast_attn_bias_dim_0 = attention_bias_dims.size() > 0 && attention_bias_dims[0] == 1; + parameters.broadcast_attn_bias_dim_1 
= attention_bias_dims.size() > 1 && attention_bias_dims[1] == 1; parameters.batch_size = static_cast(batch_size); parameters.sequence_length = static_cast(sequence_length); @@ -250,7 +252,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { Tensor* output = context->Output(0, output_shape); auto& device_prop = this->GetDeviceProp(); - MHARunner* fused_runner = this->GetFusedRunner(device_prop, parameters); + MHARunner* fused_runner = this->GetFusedRunner(device_prop, attention_bias != nullptr, parameters); bool use_memory_efficient_attention = false; #if USE_MEMORY_EFFICIENT_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.h b/onnxruntime/contrib_ops/cuda/bert/packed_attention.h index cad28e7b70057..6fcacd4d46ada 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.h @@ -23,7 +23,9 @@ class TrtFusedAttention : public CudaKernel { TrtFusedAttention(const OpKernelInfo& info); protected: - MHARunner* GetFusedRunner(const cudaDeviceProp& device_prop, const PackedAttentionParameters& parameters) const; + MHARunner* GetFusedRunner(const cudaDeviceProp& device_prop, + bool has_attention_bias, + const PackedAttentionParameters& parameters) const; protected: const AttentionKernelOptions* kernel_options_; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index 890413b82d23f..849a57512dc3d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -523,8 +523,11 @@ Status FusedScaledDotProductAttentionCutlass( p.query = query; p.key = key; p.value = value; + p.attn_bias = data.attention_bias; - p.attn_bias_dims = parameters.attention_bias_dims; + p.broadcast_attn_bias_dim_0 = parameters.broadcast_attn_bias_dim_0; + p.broadcast_attn_bias_dim_1 = parameters.broadcast_attn_bias_dim_1; + p.output = 
data.output; p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? accum_workspace : nullptr; p.stream = stream; @@ -603,8 +606,8 @@ Status UnfusedScaledDotProductAttention( sequence_length); T* attention_score = scaled_qk + (bytes / element_size); - bool broadcast_attn_bias_dim_0 = parameters.attention_bias_dims.size() > 0 && parameters.attention_bias_dims[0] == 1; - bool broadcast_attn_bias_dim_1 = parameters.attention_bias_dims.size() > 1 && parameters.attention_bias_dims[1] == 1; + const bool broadcast_attn_bias_dim_0 = parameters.broadcast_attn_bias_dim_0; + const bool broadcast_attn_bias_dim_1 = parameters.broadcast_attn_bias_dim_1; // Apply softmax and store result R to attention_score: BxNxSxS ORT_RETURN_IF_ERROR(ComputeSoftmaxWithCumSeqLength( diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc index f9714e00c493f..35f43aa9fdc7b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc @@ -156,7 +156,8 @@ Status PackedMultiHeadAttention::CheckInputs(const TensorShape& query_shape, ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckAttentionBias( attention_bias_dims, batch_size, num_heads, sequence_length, sequence_length)); } - parameters.attention_bias_dims = attention_bias_dims; + parameters.broadcast_attn_bias_dim_0 = attention_bias_dims.size() > 0 && attention_bias_dims[0] == 1; + parameters.broadcast_attn_bias_dim_1 = attention_bias_dims.size() > 1 && attention_bias_dims[1] == 1; parameters.batch_size = static_cast(batch_size); parameters.sequence_length = static_cast(sequence_length); @@ -216,7 +217,9 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co } #endif - MHARunner* fused_runner = use_flash_attention ? 
nullptr : this->GetFusedRunner(device_prop, parameters); + MHARunner* fused_runner = use_flash_attention + ? nullptr + : this->GetFusedRunner(device_prop, attention_bias != nullptr, parameters); bool use_memory_efficient_attention = false; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index 7bcb589c1d98b..c00eefc8e49de 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -701,8 +701,11 @@ Status FusedAttentionCutlass( p.query = data.no_qkv_workspace ? data.query : data.workspace; p.key = data.no_qkv_workspace ? data.key : (data.workspace + elements_qk); p.value = data.no_qkv_workspace ? data.value : (data.workspace + elements_qk + elements_qk); + p.attn_bias = data.attention_bias; - p.attn_bias_dims = parameters.attention_bias_dims; + p.broadcast_attn_bias_dim_0 = parameters.broadcast_attn_bias_dim_0; + p.broadcast_attn_bias_dim_1 = parameters.broadcast_attn_bias_dim_1; + p.output = data.output; p.is_kv_bsnh = true; p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) @@ -788,14 +791,12 @@ Status UnfusedAttention( sequence_length); T* attention_score = scaled_qk + (bytes / element_size); - bool broadcast_attn_bias_dim_0 = parameters.attention_bias_dims.size() > 0 && parameters.attention_bias_dims[0] == 1; - bool broadcast_attn_bias_dim_1 = parameters.attention_bias_dims.size() > 1 && parameters.attention_bias_dims[1] == 1; // Apply softmax and store result R to attention_score: BxNxSxS ORT_RETURN_IF_ERROR(ComputeSoftmaxWithCumSeqLength( scaled_qk, data.attention_bias, - broadcast_attn_bias_dim_0, - broadcast_attn_bias_dim_1, + parameters.broadcast_attn_bias_dim_0, + parameters.broadcast_attn_bias_dim_1, data.cumulative_sequence_length, batch_size, sequence_length, diff --git 
a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index 791fff2a8969d..c2a89232145c1 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -774,12 +774,7 @@ def run_tflops_test( # flash attention is available for sm >= 80 sm = get_compute_capability() if sm >= 80: - backends = [ - SdpaKernel.DEFAULT, - SdpaKernel.FLASH_ATTENTION, - SdpaKernel.EFFICIENT_ATTENTION, - SdpaKernel.CUDNN_FLASH_ATTENTION, - ] + backends = [SdpaKernel.DEFAULT, SdpaKernel.FLASH_ATTENTION, SdpaKernel.EFFICIENT_ATTENTION] else: backends = [SdpaKernel.DEFAULT, SdpaKernel.EFFICIENT_ATTENTION] else: From 6766b17eae54168b36902c16603f21327aa3f15f Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 15 Aug 2024 18:11:50 +0000 Subject: [PATCH 09/13] refactoring cpu; add comments --- .../contrib_ops/cpu/bert/attention_cpu_base.h | 79 ++++++++++--------- .../contrib_ops/cuda/bert/attention_impl.cu | 4 +- .../src/Operators/DmlOperatorAttention.cpp | 11 +-- .../DmlOperatorMultiHeadAttention.cpp | 1 + 4 files changed, 49 insertions(+), 46 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index d9516b6edc2c4..ae2eaf0204026 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -35,7 +35,7 @@ class AttentionCPUBase : public AttentionBase { int qk_head_size, // head size of Q or K (H) int v_head_size, // head size of V (H_v) int v_hidden_size, // hidden size of V (D_v) - const Tensor* attn_bias, // additive bias applied on QK. Its size is BxNxSxT or 1xNxSxT + const Tensor* attn_bias, // additive bias applied on scaled QK. 
OpKernelContext* context) const { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); @@ -87,7 +87,7 @@ class AttentionCPUBase : public AttentionBase { T* present_value_data = present_value != nullptr ? present_value->MutableData() : nullptr; const T* attn_bias_data = (attn_bias != nullptr) ? attn_bias->Data() : nullptr; - auto attn_bias_shape = (attn_bias != nullptr) ? attn_bias->Shape().GetDims() : gsl::span{}; + auto attn_bias_dims = (attn_bias != nullptr) ? attn_bias->Shape().GetDims() : gsl::span{}; // Compute the attention score. size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * total_sequence_length * sizeof(T); @@ -97,7 +97,7 @@ class AttentionCPUBase : public AttentionBase { static_cast(mask_data), batch_size, sequence_length, kv_sequence_length, past_sequence_length, qk_head_size == 0 ? v_head_size : qk_head_size, past_data, past_key_data, - present_data, present_key_data, tp, scale, attn_bias_data, attn_bias_shape); + present_data, present_key_data, tp, scale, attn_bias_data, attn_bias_dims); // Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) auto out_tmp_data = @@ -117,23 +117,24 @@ class AttentionCPUBase : public AttentionBase { // 1 x mask_data(B, N, S, T) // attention_probs(B, N, S, T) = Softmax(attention_probs) template - void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT - const T* Q, // Q data. Its size is BxNxSxH - const T* K, // k data. Its size is BxNxLxH - T* mask_data, // buffer for mask data. 
- int batch_size, // batch size of self-attention - int sequence_length, // sequence length of self-attention (S) - int kv_sequence_length, // sequence length of cross-attention (L) - int past_sequence_length, // sequence length of past state - int head_size, // head size of self-attention - const T* past, // past state - const T* past_key, // past key only (if not using past state) - T* present, // present state - T* present_key, // present key only (if not using present state) - ThreadPool* tp, // thread pool - float scale, // scale factor - const T* attn_bias_data, // bias addition matrix with shape BxNxSxT or 1xNxSxT - gsl::span attn_bias_shape) const { + void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT + const T* Q, // Q data. Its size is BxNxSxH + const T* K, // k data. Its size is BxNxLxH + T* mask_data, // buffer for mask data. + int batch_size, // batch size of self-attention + int sequence_length, // sequence length of self-attention (S) + int kv_sequence_length, // sequence length of cross-attention (L) + int past_sequence_length, // sequence length of past state + int head_size, // head size of self-attention + const T* past, // past state + const T* past_key, // past key only (if not using past state) + T* present, // present state + T* present_key, // present key only (if not using present state) + ThreadPool* tp, // thread pool + float scale, // scale factor + const T* attn_bias_data, // attention bias + gsl::span attn_bias_dims // attention bias shape + ) const { const int total_sequence_length = past_sequence_length + kv_sequence_length; // T = P + L const size_t past_chunk_length = static_cast(past_sequence_length) * head_size; // P x H const size_t q_input_chunk_length = static_cast(sequence_length) * head_size; // S x H @@ -143,16 +144,17 @@ class AttentionCPUBase : public AttentionBase { DUMP_CPU_TENSOR_INIT(); DUMP_CPU_TENSOR("Q", Q, batch_size, num_heads_, sequence_length, head_size); DUMP_CPU_TENSOR("K", K, 
batch_size, num_heads_, total_sequence_length, head_size); - DUMP_CPU_TENSOR("Attn_Bias", attn_bias_data, attn_bias_shape); + DUMP_CPU_TENSOR("Attn_Bias", attn_bias_data, attn_bias_dims); { const int loop_len = batch_size * num_heads_; const float alpha = scale; TensorOpCost unit_cost; - const ptrdiff_t probs_matrix_bytes = SafeInt(sequence_length) * total_sequence_length * sizeof(T); + const ptrdiff_t probs_matrix_size = SafeInt(sequence_length) * total_sequence_length; + const ptrdiff_t probs_matrix_bytes = probs_matrix_size * sizeof(T); unit_cost.compute_cycles = - static_cast(SafeInt(2) * sequence_length * head_size * total_sequence_length); + static_cast(SafeInt(2) * head_size * probs_matrix_size); unit_cost.bytes_loaded = static_cast((sequence_length + total_sequence_length) * head_size * sizeof(T)); unit_cost.bytes_stored = static_cast(probs_matrix_bytes); @@ -168,7 +170,7 @@ class AttentionCPUBase : public AttentionBase { } if (attn_bias_data != nullptr) { - unit_cost.compute_cycles += static_cast(sequence_length * total_sequence_length); + unit_cost.compute_cycles += static_cast(probs_matrix_size); unit_cost.bytes_loaded += probs_matrix_bytes * 2; unit_cost.bytes_stored += probs_matrix_bytes; } @@ -178,28 +180,27 @@ class AttentionCPUBase : public AttentionBase { const int batch_index = static_cast(i) / num_heads_; const std::ptrdiff_t head_index = i % static_cast(num_heads_); - const ptrdiff_t output_offset = SafeInt(i) * sequence_length * total_sequence_length; - const ptrdiff_t mask_offset = SafeInt(batch_index) * sequence_length * total_sequence_length; + const ptrdiff_t output_offset = SafeInt(i) * probs_matrix_size; + const ptrdiff_t mask_offset = SafeInt(batch_index) * probs_matrix_size; + + T* output = attention_probs + output_offset; - ptrdiff_t attn_bias_offset = 0; if (attn_bias_data != nullptr) { - // broadcast of batch dim with shape (1, N or 1, S, T) - if (attn_bias_shape[0] != 1) { - attn_bias_offset += SafeInt(batch_index) * num_heads_ * 
sequence_length * total_sequence_length; + // Attention bias has shape (B or 1, N or 1, S, T) + // Here we handle the broadcast of batch_size and num_heads dimensions. + ptrdiff_t attn_bias_offset = 0; + if (attn_bias_dims[0] != 1) { + attn_bias_offset += SafeInt(batch_index) * num_heads_ * probs_matrix_size; } - - // broadcast of head dim with shape (B or 1, 1, S, T) - if (attn_bias_shape[1] != 1) { - attn_bias_offset += head_index * sequence_length * total_sequence_length; + if (attn_bias_dims[1] != 1) { + attn_bias_offset += head_index * probs_matrix_size; } - } - T* output = attention_probs + output_offset; - - if (attn_bias_data != nullptr) { memcpy(output, attn_bias_data + attn_bias_offset, probs_matrix_bytes); + if (mask_data != nullptr) { - for (int j = 0; j < sequence_length * total_sequence_length; j++) { + // This can be optimized with vectorized add using MlasAddFloat32x4. + for (ptrdiff_t j = 0; j < probs_matrix_size; j++) { output[j] += mask_data[mask_offset + j]; } } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 107f08b4c89cc..28e2b7b28764b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -354,8 +354,8 @@ Status EfficientAttention( p.seqstart_k_ptr = nullptr; } else { p.seqlen_k_ptr = const_cast(reinterpret_cast(data.mask_index)); - p.seqstart_q_ptr = const_cast(reinterpret_cast(data.mask_index + parameters.batch_size)); - p.seqstart_k_ptr = const_cast(reinterpret_cast(data.mask_index + 2 * parameters.batch_size + 1)); + p.seqstart_q_ptr = p.seqlen_k_ptr + parameters.batch_size; + p.seqstart_k_ptr = p.seqlen_k_ptr + 2 * parameters.batch_size + 1; } p.query = data.q; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp index f913fb2e02a75..9b4a34622d460 
100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp @@ -74,11 +74,11 @@ class DmlOperatorAttention : public DmlOperator ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() >= 2); ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() >= 1); - const uint32_t dmlInputIndex = inputIndex; - const uint32_t dmlWeightsIndex = weightsIndex; - const uint32_t dmlBiasIndex = biasIndex; - const uint32_t dmlMaskIndex = maskIndex; - const uint32_t dmlAttentionBiasIndex = attentionBiasIndex; + constexpr uint32_t dmlInputIndex = inputIndex; + constexpr uint32_t dmlWeightsIndex = weightsIndex; + constexpr uint32_t dmlBiasIndex = biasIndex; + constexpr uint32_t dmlMaskIndex = maskIndex; + constexpr uint32_t dmlAttentionBiasIndex = attentionBiasIndex; const bool hasBias = kernelCreationContext.IsInputValid(biasIndex); const bool hasMask = kernelCreationContext.IsInputValid(maskIndex); @@ -192,6 +192,7 @@ class DmlOperatorAttention : public DmlOperator { auto attentionBiasTensorShape = m_inputTensorDescs[dmlAttentionBiasIndex].GetSizes(); ML_CHECK_VALID_ARGUMENT(attentionBiasTensorShape.size() == 4); + // TODO: support broadcast of attention bias on the first and second dimensions. 
ML_CHECK_VALID_ARGUMENT(attentionBiasTensorShape[0] == inputTensorShape[0]); ML_CHECK_VALID_ARGUMENT(attentionBiasTensorShape[1] == numHeads); ML_CHECK_VALID_ARGUMENT(attentionBiasTensorShape[2] == inputTensorShape[1]); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp index 1e747e34c9acb..d781aea8515a6 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp @@ -248,6 +248,7 @@ class DmlOperatorMultiHeadAttention : public DmlOperator ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[dmlAttentionBiasIndex].GetDimensionCount() == 4); auto attentionBiasSizes = m_inputTensorDescs[dmlAttentionBiasIndex].GetSizes(); + // TODO: support broadcast of attention bias on the first and second dimensions. 
ML_CHECK_VALID_ARGUMENT(attentionBiasSizes[0] == batchSize); ML_CHECK_VALID_ARGUMENT(attentionBiasSizes[1] == numHeads); ML_CHECK_VALID_ARGUMENT(attentionBiasSizes[2] == sequenceLength); From 4984b45558e07071cd492a2af3f77e9b9c331ea3 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 16 Aug 2024 06:14:56 +0000 Subject: [PATCH 10/13] refine softmax kernel --- .../contrib_ops/cpu/bert/attention_base.cc | 2 +- .../cuda/bert/attention_softmax.cu | 620 ++++++++---------- .../test/python/transformers/benchmark_mha.py | 256 +++++--- 3 files changed, 442 insertions(+), 436 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc index 4573913776e2c..52dcb990ab67f 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc @@ -38,7 +38,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, // bias (Q/K/V) : (D + D + D_v) // mask_index : see below // past (K/V) : (2, B, N, P, H) or NULL - // attention_bias : (B, N, S, T) or NULL + // attention_bias : (B or 1, N or 1, S, T) or NULL // For mask_index, the following shapes are supported: // NULL, (B, 1), (1, 1) diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu index b11a6aa887039..f4647a514e7e8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu @@ -29,9 +29,40 @@ namespace onnxruntime { namespace contrib { namespace attention_softmax_cuda { +#define DISPATCH_BIAS(attn_bias, HAS_BIAS, ...) 
\ + [&] { \ + const dim3 grid(num_heads * sequence_length, batch_size, 1); \ + if (attn_bias != nullptr) { \ + constexpr static bool HAS_BIAS = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool HAS_BIAS = false; \ + return __VA_ARGS__(); \ + } \ + }() + +// Macro to declare variables: +// offset: offset in input/output +// bias_offset: offset in attn_bias +// b: batch index +// s: sequence index +// grid size is (num_heads * sequence_length, batch_size, 1) +// input and output shape is (batch_size, num_heads, sequence_length, total_sequence_length) +// bias shape is (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length) +#define DECLARE_SOFTMAX_VARS() \ + const int s = blockIdx.x % sequence_length; \ + const int b = blockIdx.y; \ + int64_t offset = static_cast(b * gridDim.x + blockIdx.x) * static_cast(total_sequence_length); \ + int64_t bias_offset = 0; \ + if constexpr (HAS_BIAS) { \ + const int j = (broadcast_attn_bias_dim_0 ? 0 : (b * gridDim.x)) + (broadcast_attn_bias_dim_1 ? s : blockIdx.x); \ + bias_offset = static_cast(j) * static_cast(total_sequence_length); \ + } + // This kernel is for non causal, attention mask 1D or None, and total_sequence_length > 1024. -template +template __device__ inline void Softmax(const int total_sequence_length, + const int sequence_length, const int valid_end, const int valid_start, const T* attn_bias, @@ -47,26 +78,7 @@ __device__ inline void Softmax(const int total_sequence_length, float thread_data_max(-CUDART_INF_F); - // grid size is (sequence_length, num_heads, batch_size); total_sequence_length is partitioned to blocks by TPB. 
- const int sequence_length = gridDim.x; - const int num_heads = gridDim.y; - const int s = blockIdx.x; - const int n = blockIdx.y; - const int b = blockIdx.z; - - // input and output shape is (batch_size, num_heads, sequence_length, total_sequence_length) - int block = b * num_heads * sequence_length + n * sequence_length + s; - const int64_t offset = static_cast(block) * static_cast(total_sequence_length); - - const bool has_bias = (attn_bias != nullptr); - int64_t bias_offset = 0; - if (has_bias) { - // bias shape is (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length) - block = (broadcast_attn_bias_dim_0 ? 0 : (b * num_heads * sequence_length)) + - (broadcast_attn_bias_dim_1 ? 0 : (n * sequence_length)) + - s; - bias_offset = static_cast(block) * static_cast(total_sequence_length); - } + DECLARE_SOFTMAX_VARS(); // e^x is represented as infinity if x is large enough, like 100.f. // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. @@ -74,7 +86,7 @@ __device__ inline void Softmax(const int total_sequence_length, // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) for (int i = threadIdx.x; i < valid_end; i += TPB) { if (i >= valid_start) { - float input_data = has_bias + float input_data = HAS_BIAS ? float(input[offset + i]) + float(attn_bias[bias_offset + i]) : float(input[offset + i]); if (thread_data_max < input_data) { @@ -93,7 +105,7 @@ __device__ inline void Softmax(const int total_sequence_length, float thread_data_sum(0.f); for (int i = threadIdx.x; i < valid_end; i += TPB) { if (i >= valid_start) { - float input_data = has_bias + float input_data = HAS_BIAS ? 
float(input[offset + i]) + float(attn_bias[bias_offset + i]) : float(input[offset + i]); @@ -109,7 +121,7 @@ __device__ inline void Softmax(const int total_sequence_length, for (int i = threadIdx.x; i < total_sequence_length; i += TPB) { const int index = offset + i; - float input_data = has_bias + float input_data = HAS_BIAS ? float(input[index]) + float(attn_bias[bias_offset + i]) : float(input[index]); const float val = (i >= valid_start && i < valid_end) ? expf(input_data - max_block) * sum_reverse_block : 0.f; @@ -118,8 +130,9 @@ __device__ inline void Softmax(const int total_sequence_length, } // This kernel is for non causal, attention mask 1D or None, and total_sequence_length <= 1024. -template +template __device__ inline void SoftmaxSmall(const int total_sequence_length, + const int sequence_length, const int valid_end, const int valid_start, const T* attn_bias, @@ -134,28 +147,10 @@ __device__ inline void SoftmaxSmall(const int total_sequence_length, __shared__ float sum_reverse_block; __shared__ float max_block; - // grid size is (sequence_length, num_heads, batch_size); total_sequence_length is within one block size TPB. - const int sequence_length = gridDim.x; - const int num_heads = gridDim.y; - const int s = blockIdx.x; - const int n = blockIdx.y; - const int b = blockIdx.z; + DECLARE_SOFTMAX_VARS(); - // input and output shape is (batch_size, num_heads, sequence_length, total_sequence_length) - int block = b * num_heads * sequence_length + n * sequence_length + s; - const int64_t offset = static_cast(block) * static_cast(total_sequence_length); const int index = offset + threadIdx.x; - const bool has_bias = (attn_bias != nullptr); - int64_t bias_offset = 0; - if (has_bias) { - // bias shape is (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length) - block = (broadcast_attn_bias_dim_0 ? 0 : (b * num_heads * sequence_length)) + - (broadcast_attn_bias_dim_1 ? 
0 : (n * sequence_length)) + - s; - bias_offset = static_cast(block) * static_cast(total_sequence_length); - } - // Update end position for causal. int end = valid_end; if (causal) { @@ -166,7 +161,7 @@ __device__ inline void SoftmaxSmall(const int total_sequence_length, } const bool is_valid = (threadIdx.x >= valid_start && threadIdx.x < end); - float input_data = is_valid ? (has_bias + float input_data = is_valid ? (HAS_BIAS ? float(input[index]) + float(attn_bias[bias_offset + threadIdx.x]) : float(input[index])) : float(-CUDART_INF_F); @@ -203,8 +198,9 @@ __device__ inline void SoftmaxSmall(const int total_sequence_length, } // This kernel is for causal or not, attention mask 1D or None, and total_sequence_length <= 1024. -template +template __global__ void SoftmaxLargeKernel(const int total_sequence_length, + const int sequence_length, const int valid_end, const int valid_start, const T* attn_bias, @@ -221,12 +217,7 @@ __global__ void SoftmaxLargeKernel(const int total_sequence_length, __shared__ float sum_reverse_block; __shared__ float max_block; - // grid size is (sequence_length, num_heads, batch_size); total_sequence_length is partitioned by TPB. - const int sequence_length = gridDim.x; - const int num_heads = gridDim.y; - const int s = blockIdx.x; - const int n = blockIdx.y; - const int b = blockIdx.z; + DECLARE_SOFTMAX_VARS(); // Update end position for causal. int end = valid_end; @@ -237,25 +228,11 @@ __global__ void SoftmaxLargeKernel(const int total_sequence_length, } } - // input and output shape is (batch_size, num_heads, sequence_length, total_sequence_length) - int block = b * num_heads * sequence_length + n * sequence_length + s; - const int64_t offset = static_cast(block) * static_cast(total_sequence_length); - - const bool has_bias = (attn_bias != nullptr); - int64_t bias_offset = 0; - if (has_bias) { - // bias shape is (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length) - block = (broadcast_attn_bias_dim_0 ? 
0 : (b * num_heads * sequence_length)) + - (broadcast_attn_bias_dim_1 ? 0 : (n * sequence_length)) + - s; - bias_offset = static_cast(block) * static_cast(total_sequence_length); - } - float thread_data_max = -CUDART_INF_F; for (int i = threadIdx.x; i < total_sequence_length; i += TPB) { const int index = offset + i; const bool is_valid = (i >= valid_start && i < end); - float input_data = is_valid ? (has_bias + float input_data = is_valid ? (HAS_BIAS ? float(input[index]) + float(attn_bias[bias_offset + i]) : float(input[index])) : float(-CUDART_INF_F); @@ -292,8 +269,9 @@ __global__ void SoftmaxLargeKernel(const int total_sequence_length, } // This kernel is for causal or not, raw attention mask (2D, 3D or 4D) and total_sequence_length > 1024. -template +template __global__ void SoftmaxWithRawMaskLargeKernel(const int total_sequence_length, + const int sequence_length, const int* attention_mask, // 2D, 3D or 4D attention mask const bool* key_padding_mask, const T* attn_bias, @@ -317,36 +295,14 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int total_sequence_length, float max_thread_data = -CUDART_INF_F; - // grid size is (sequence_length, num_heads, batch_size); total_sequence_length is partitioned by TPB. - const int sequence_length = gridDim.x; - const int num_heads = gridDim.y; - const int s = blockIdx.x; - const int n = blockIdx.y; - const int b = blockIdx.z; - - // input and output shape is (batch_size, num_heads, sequence_length, total_sequence_length) - int block = b * num_heads * sequence_length + n * sequence_length + s; - const int64_t offset = static_cast(block) * static_cast(total_sequence_length); - - const bool has_bias = (attn_bias != nullptr); - int64_t bias_offset = 0; - if (has_bias) { - // bias shape is (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length) - block = (broadcast_attn_bias_dim_0 ? 0 : (b * num_heads * sequence_length)) + - (broadcast_attn_bias_dim_1 ? 
0 : (n * sequence_length)) + - s; - bias_offset = static_cast(block) * static_cast(total_sequence_length); - } + DECLARE_SOFTMAX_VARS(); for (int i = threadIdx.x; i < total_sequence_length; i += TPB) { - float thread_data = -CUDART_INF_F; int index = offset + i; - if (attn_bias == nullptr) { - thread_data = float(input[index]) * rsqrt_head_size; - } else { - thread_data = (float(input[index]) + float(attn_bias[bias_offset + i])) * rsqrt_head_size; - } - + float input_data = HAS_BIAS + ? float(input[index]) + float(attn_bias[bias_offset + i]) + : float(input[index]); + float thread_data = input_data * rsqrt_head_size; if (causal) { int from_index = total_sequence_length - sequence_length + s; // offset in total sequence length. if (i > from_index) { @@ -414,8 +370,9 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int total_sequence_length, } // This kernel is for causal or not, raw attention mask (2D, 3D or 4D), and total_sequence_length <= 1024. -template +template __device__ inline void SoftmaxWithRawMaskSmall(const int total_sequence_length, + const int sequence_length, const int* attention_mask, // 2D, 3D or 4D attention mask const bool* key_padding_mask, const T* attn_bias, @@ -435,26 +392,7 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int total_sequence_length, __shared__ float sum_reverse_block; __shared__ float max_block; - // grid size is (sequence_length, num_heads, batch_size); total_sequence_length is within one block size TPB. 
- const int sequence_length = gridDim.x; - const int num_heads = gridDim.y; - const int s = blockIdx.x; - const int n = blockIdx.y; - const int b = blockIdx.z; - - // input and output shape is (batch_size, num_heads, sequence_length, total_sequence_length) - int block = b * num_heads * sequence_length + n * sequence_length + s; - const int64_t offset = static_cast(block) * static_cast(total_sequence_length); - - const bool has_bias = (attn_bias != nullptr); - int64_t bias_offset = 0; - if (has_bias) { - // bias shape is (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length) - block = (broadcast_attn_bias_dim_0 ? 0 : (b * num_heads * sequence_length)) + - (broadcast_attn_bias_dim_1 ? 0 : (n * sequence_length)) + - s; - bias_offset = static_cast(block) * static_cast(total_sequence_length); - } + DECLARE_SOFTMAX_VARS(); int64_t index = offset + threadIdx.x; @@ -490,7 +428,7 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int total_sequence_length, } } - if (attn_bias != nullptr) { + if (HAS_BIAS) { thread_data += float(attn_bias[bias_offset + threadIdx.x]); } } @@ -524,27 +462,29 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int total_sequence_length, } } -template +template __global__ void SoftmaxKernelSmall(const int total_sequence_length, + const int sequence_length, const T* attn_bias, const bool broadcast_attn_bias_dim_0, const bool broadcast_attn_bias_dim_1, const T* input, T* output, bool causal) { - SoftmaxSmall(total_sequence_length, total_sequence_length, 0, - attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); + SoftmaxSmall(total_sequence_length, sequence_length, total_sequence_length, 0, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); } -template +template __global__ void SoftmaxKernel(const int total_sequence_length, + const int sequence_length, const T* attn_bias, const bool broadcast_attn_bias_dim_0, const bool broadcast_attn_bias_dim_1, 
const T* input, T* output) { - Softmax(total_sequence_length, total_sequence_length, 0, - attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output); + Softmax(total_sequence_length, sequence_length, total_sequence_length, 0, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output); } template @@ -552,49 +492,57 @@ Status ComputeSoftmax(cudaStream_t stream, const int total_sequence_length, cons const int batch_size, const int num_heads, const T* attn_bias, const bool broadcast_attn_bias_dim_0, const bool broadcast_attn_bias_dim_1, T* input, T* output, bool causal) { - const dim3 grid(sequence_length, num_heads, batch_size); - if (total_sequence_length <= 32) { - const int blockSize = 32; - SoftmaxKernelSmall<<>>( - total_sequence_length, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); - } else if (total_sequence_length <= 64) { - const int blockSize = 64; - SoftmaxKernelSmall<<>>( - total_sequence_length, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); - } else if (total_sequence_length <= 128) { - const int blockSize = 128; - SoftmaxKernelSmall<<>>( - total_sequence_length, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); - } else if (total_sequence_length <= 256) { - const int blockSize = 256; - SoftmaxKernelSmall<<>>( - total_sequence_length, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); - } else if (total_sequence_length <= 512) { - const int blockSize = 512; - SoftmaxKernelSmall<<>>( - total_sequence_length, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); - } else if (total_sequence_length <= 1024) { - const int blockSize = 1024; - SoftmaxKernelSmall<<>>( - total_sequence_length, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); - } else if (!causal) { - const int blockSize 
= 1024; - SoftmaxKernel<<>>( - total_sequence_length, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output); - } else { - const int blockSize = 256; - const int sh_bytes = sizeof(float) * total_sequence_length; - SoftmaxLargeKernel<<>>( - total_sequence_length, total_sequence_length, 0, attn_bias, - broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, - input, output, true); - } - + DISPATCH_BIAS(attn_bias, HAS_BIAS, [&] { + if (total_sequence_length <= 32) { + const int blockSize = 32; + SoftmaxKernelSmall<<>>( + total_sequence_length, sequence_length, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); + } else if (total_sequence_length <= 64) { + const int blockSize = 64; + SoftmaxKernelSmall<<>>( + total_sequence_length, sequence_length, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); + } else if (total_sequence_length <= 128) { + const int blockSize = 128; + SoftmaxKernelSmall<<>>( + total_sequence_length, sequence_length, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); + } else if (total_sequence_length <= 256) { + const int blockSize = 256; + SoftmaxKernelSmall<<>>( + total_sequence_length, sequence_length, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); + } else if (total_sequence_length <= 512) { + const int blockSize = 512; + SoftmaxKernelSmall<<>>( + total_sequence_length, sequence_length, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); + } else if (total_sequence_length <= 1024) { + const int blockSize = 1024; + SoftmaxKernelSmall<<>>( + total_sequence_length, sequence_length, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); + } else if (!causal) { + const int blockSize = 1024; + SoftmaxKernel<<>>( + total_sequence_length, sequence_length, + attn_bias, broadcast_attn_bias_dim_0, 
broadcast_attn_bias_dim_1, input, output); + } else { + const int blockSize = 256; + const int sh_bytes = sizeof(float) * total_sequence_length; + SoftmaxLargeKernel<<>>( + total_sequence_length, sequence_length, total_sequence_length, 0, attn_bias, + broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + input, output, true); + } + }); return CUDA_CALL(cudaGetLastError()); } -template +template __global__ void MaskedSoftmaxKernelSmall(const int total_sequence_length, + const int sequence_length, const int* mask_end, const int* mask_start, const T* attn_bias, @@ -607,7 +555,7 @@ __global__ void MaskedSoftmaxKernelSmall(const int total_sequence_length, __shared__ int end_position; if (threadIdx.x == 0) { - const int batch = blockIdx.z; + const int batch = blockIdx.y; start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; end_position = min(total_sequence_length, mask_end[batch]); @@ -619,12 +567,13 @@ __global__ void MaskedSoftmaxKernelSmall(const int total_sequence_length, } __syncthreads(); - SoftmaxSmall(total_sequence_length, end_position, start_position, - attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); + SoftmaxSmall(total_sequence_length, sequence_length, end_position, start_position, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal); } -template +template __device__ inline void SoftmaxSmallPacked(const int total_sequence_length, + const int sequence_length, const int end, const T* attn_bias, const bool broadcast_attn_bias_dim_0, @@ -637,33 +586,13 @@ __device__ inline void SoftmaxSmallPacked(const int total_sequence_length, __shared__ float sum_reverse_block; __shared__ float max_block; - - // grid size is (sequence_length, num_heads, batch_size); total_sequence_length is within TPB. 
- const int sequence_length = gridDim.x; - const int num_heads = gridDim.y; - const int s = blockIdx.x; - const int n = blockIdx.y; - const int b = blockIdx.z; - - // input and output shape is (batch_size, num_heads, sequence_length, total_sequence_length) - int block = b * num_heads * sequence_length + n * sequence_length + s; - const int64_t offset = static_cast(block) * static_cast(total_sequence_length); - - const bool has_bias = (attn_bias != nullptr); - int64_t bias_offset = 0; - if (has_bias) { - // bias shape is (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length) - block = (broadcast_attn_bias_dim_0 ? 0 : (b * num_heads * sequence_length)) + - (broadcast_attn_bias_dim_1 ? 0 : (n * sequence_length)) + - s; - bias_offset = static_cast(block) * static_cast(total_sequence_length); - } + DECLARE_SOFTMAX_VARS(); int64_t index = offset + threadIdx.x; bool is_valid = threadIdx.x < end; - float input_data = has_bias ? float(input[index]) + float(attn_bias[bias_offset + threadIdx.x]) : float(input[index]); + float input_data = HAS_BIAS ? float(input[index]) + float(attn_bias[bias_offset + threadIdx.x]) : float(input[index]); float thread_data_max = is_valid ? 
input_data : float(-CUDART_INF_F); const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end); @@ -693,48 +622,52 @@ __device__ inline void SoftmaxSmallPacked(const int total_sequence_length, } } -template +template __global__ void SoftmaxKernelSmallWithCumSeqLen(const T* input, const T* attn_bias, const bool broadcast_attn_bias_dim_0, const bool broadcast_attn_bias_dim_1, const int* cum_seq_length, const int total_sequence_length, + const int sequence_length, T* output) { __shared__ int end_position; if (threadIdx.x == 0) { - const int batch = blockIdx.z; + const int batch = blockIdx.y; end_position = cum_seq_length[batch + 1] - cum_seq_length[batch]; } __syncthreads(); - SoftmaxSmallPacked(total_sequence_length, end_position, - attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output); + SoftmaxSmallPacked(total_sequence_length, sequence_length, end_position, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output); } -template +template __global__ void SoftmaxKernelWithCumSeqLen(const T* input, const T* attn_bias, const bool broadcast_attn_bias_dim_0, const bool broadcast_attn_bias_dim_1, const int* cum_seq_length, const int total_sequence_length, + const int sequence_length, T* output) { __shared__ int end_position; if (threadIdx.x == 0) { - const int batch = blockIdx.z; + const int batch = blockIdx.y; end_position = cum_seq_length[batch + 1] - cum_seq_length[batch]; } __syncthreads(); - Softmax(total_sequence_length, end_position, 0 /*start_position*/, - attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output); + constexpr int start_position = 0; + Softmax(total_sequence_length, sequence_length, end_position, start_position, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output); } -template +template __global__ void MaskedSoftmaxKernel(const int total_sequence_length, + const int sequence_length, const int* mask_end, const int* mask_start, 
const T* attn_bias, @@ -745,7 +678,7 @@ __global__ void MaskedSoftmaxKernel(const int total_sequence_length, __shared__ int end_position; if (threadIdx.x == 0) { - const int batch = blockIdx.z; + const int batch = blockIdx.y; start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; end_position = min(total_sequence_length, mask_end[batch]); @@ -757,12 +690,13 @@ __global__ void MaskedSoftmaxKernel(const int total_sequence_length, } __syncthreads(); - Softmax(total_sequence_length, end_position, start_position, - attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output); + Softmax(total_sequence_length, sequence_length, end_position, start_position, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output); } -template +template __global__ void SoftmaxWithRawMaskSmallKernel(const int total_sequence_length, + const int sequence_length, const int* attention_mask, const bool* key_padding_mask, const T* attn_bias, @@ -776,8 +710,8 @@ __global__ void SoftmaxWithRawMaskSmallKernel(const int total_sequence_length, const int max_sequence_length, const bool skip_softmax, const float mask_filter_value) { - SoftmaxWithRawMaskSmall( - total_sequence_length, attention_mask, key_padding_mask, + SoftmaxWithRawMaskSmall( + total_sequence_length, sequence_length, attention_mask, key_padding_mask, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal, rsqrt_head_size, mask_dimension, max_sequence_length, skip_softmax, mask_filter_value); @@ -795,44 +729,44 @@ Status ComputeSoftmaxWithCumSeqLength( const int total_sequence_length, const int num_heads, T* output, cudaStream_t stream) { - const dim3 grid(sequence_length, num_heads, batch_size); - - if (sequence_length <= 32) { - const int blockSize = 32; - SoftmaxKernelSmallWithCumSeqLen - <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, - cum_seq_length, total_sequence_length, output); - - } else if (sequence_length 
<= 64) { - const int blockSize = 64; - SoftmaxKernelSmallWithCumSeqLen - <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, - cum_seq_length, total_sequence_length, output); - } else if (sequence_length <= 128) { - const int blockSize = 128; - SoftmaxKernelSmallWithCumSeqLen - <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, - cum_seq_length, total_sequence_length, output); - } else if (sequence_length <= 256) { - const int blockSize = 256; - SoftmaxKernelSmallWithCumSeqLen - <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, - cum_seq_length, total_sequence_length, output); - } else if (sequence_length <= 512) { - const int blockSize = 512; - SoftmaxKernelSmallWithCumSeqLen - <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, - cum_seq_length, total_sequence_length, output); - } else if (sequence_length <= 1024) { - const int blockSize = 1024; - SoftmaxKernelSmallWithCumSeqLen - <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, - cum_seq_length, total_sequence_length, output); - } else { - SoftmaxKernelWithCumSeqLen - <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, - cum_seq_length, total_sequence_length, output); - } + DISPATCH_BIAS(attn_bias, HAS_BIAS, [&] { + if (sequence_length <= 32) { + const int blockSize = 32; + SoftmaxKernelSmallWithCumSeqLen + <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + cum_seq_length, total_sequence_length, sequence_length, output); + } else if (sequence_length <= 64) { + const int blockSize = 64; + SoftmaxKernelSmallWithCumSeqLen + <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + cum_seq_length, total_sequence_length, sequence_length, output); + } else if (sequence_length <= 128) { + const int blockSize = 128; + SoftmaxKernelSmallWithCumSeqLen + <<>>(input, attn_bias, broadcast_attn_bias_dim_0, 
broadcast_attn_bias_dim_1, + cum_seq_length, total_sequence_length, sequence_length, output); + } else if (sequence_length <= 256) { + const int blockSize = 256; + SoftmaxKernelSmallWithCumSeqLen + <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + cum_seq_length, total_sequence_length, sequence_length, output); + } else if (sequence_length <= 512) { + const int blockSize = 512; + SoftmaxKernelSmallWithCumSeqLen + <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + cum_seq_length, total_sequence_length, sequence_length, output); + } else if (sequence_length <= 1024) { + const int blockSize = 1024; + SoftmaxKernelSmallWithCumSeqLen + <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + cum_seq_length, total_sequence_length, sequence_length, output); + } else { + const int blockSize = 1024; + SoftmaxKernelWithCumSeqLen + <<>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + cum_seq_length, total_sequence_length, sequence_length, output); + } + }); return CUDA_CALL(cudaGetLastError()); } @@ -851,52 +785,55 @@ Status ComputeSoftmaxWithMask1D(cudaStream_t stream, const T* input, T* output, const bool causal) { - const dim3 grid(sequence_length, num_heads, batch_size); - - if (total_sequence_length <= 32) { - const int blockSize = 32; - MaskedSoftmaxKernelSmall - <<>>(total_sequence_length, mask_index, mask_start, - attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, - input, output, causal); - } else if (total_sequence_length <= 64) { - const int blockSize = 64; - MaskedSoftmaxKernelSmall - <<>>(total_sequence_length, mask_index, mask_start, - attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, - input, output, causal); - } else if (total_sequence_length <= 128) { - const int blockSize = 128; - MaskedSoftmaxKernelSmall - <<>>(total_sequence_length, mask_index, mask_start, - attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, - 
input, output, causal); - } else if (total_sequence_length <= 256) { - const int blockSize = 256; - MaskedSoftmaxKernelSmall - <<>>(total_sequence_length, mask_index, mask_start, - attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, - input, output, causal); - } else if (total_sequence_length <= 512) { - const int blockSize = 512; - MaskedSoftmaxKernelSmall - <<>>(total_sequence_length, mask_index, mask_start, - attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, - input, output, causal); - } else if (total_sequence_length <= 1024) { - const int blockSize = 1024; - MaskedSoftmaxKernelSmall - <<>>(total_sequence_length, mask_index, mask_start, - attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, - input, output, causal); - } else if (!causal) { - const int blockSize = 1024; - MaskedSoftmaxKernel - <<>>(total_sequence_length, mask_index, mask_start, - attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, - input, output); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention CUDA operator does not support total sequence length > 1024."); + DISPATCH_BIAS(attn_bias, HAS_BIAS, [&] { + if (total_sequence_length <= 32) { + const int blockSize = 32; + MaskedSoftmaxKernelSmall + <<>>(total_sequence_length, sequence_length, mask_index, mask_start, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + input, output, causal); + } else if (total_sequence_length <= 64) { + const int blockSize = 64; + MaskedSoftmaxKernelSmall + <<>>(total_sequence_length, sequence_length, mask_index, mask_start, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + input, output, causal); + } else if (total_sequence_length <= 128) { + const int blockSize = 128; + MaskedSoftmaxKernelSmall + <<>>(total_sequence_length, sequence_length, mask_index, mask_start, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + input, output, causal); + } else if (total_sequence_length <= 256) { + const int 
blockSize = 256; + MaskedSoftmaxKernelSmall + <<>>(total_sequence_length, sequence_length, mask_index, mask_start, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + input, output, causal); + } else if (total_sequence_length <= 512) { + const int blockSize = 512; + MaskedSoftmaxKernelSmall + <<>>(total_sequence_length, sequence_length, mask_index, mask_start, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + input, output, causal); + } else if (total_sequence_length <= 1024) { + const int blockSize = 1024; + MaskedSoftmaxKernelSmall + <<>>(total_sequence_length, sequence_length, mask_index, mask_start, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + input, output, causal); + } else if (!causal) { + const int blockSize = 1024; + MaskedSoftmaxKernel + <<>>(total_sequence_length, sequence_length, mask_index, mask_start, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, + input, output); + } + }); + + if (total_sequence_length > 1024 && causal) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "ComputeSoftmaxWithMask1D does not support causal with total sequence length > 1024."); } return CUDA_CALL(cudaGetLastError()); @@ -923,61 +860,62 @@ Status ComputeSoftmaxWithRawMask(Stream* ort_stream, T* persistent_softmax_workspace, const float mask_filter_value) { auto stream = static_cast(ort_stream->GetHandle()); - const dim3 grid(sequence_length, num_heads, batch_size); - T* out = use_persistent_softmax ? 
persistent_softmax_workspace : output; - if (total_sequence_length <= 32) { - const int blockSize = 32; - SoftmaxWithRawMaskSmallKernel - <<>>(total_sequence_length, attention_mask, key_padding_mask, - attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, - out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax, mask_filter_value); - } else if (total_sequence_length <= 64) { - const int blockSize = 64; - SoftmaxWithRawMaskSmallKernel - <<>>(total_sequence_length, attention_mask, key_padding_mask, - attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, - out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax, mask_filter_value); - } else if (total_sequence_length <= 128) { - const int blockSize = 128; - SoftmaxWithRawMaskSmallKernel - <<>>(total_sequence_length, attention_mask, key_padding_mask, - attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, - out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax, mask_filter_value); - } else if (total_sequence_length <= 256) { - const int blockSize = 256; - SoftmaxWithRawMaskSmallKernel - <<>>(total_sequence_length, attention_mask, key_padding_mask, - attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, - out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax, mask_filter_value); - } else if (total_sequence_length <= 512) { - const int blockSize = 512; - SoftmaxWithRawMaskSmallKernel - <<>>(total_sequence_length, attention_mask, key_padding_mask, - attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, - out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax, mask_filter_value); - } else if (total_sequence_length <= 1024) { - const int blockSize = 1024; - SoftmaxWithRawMaskSmallKernel - <<>>(total_sequence_length, attention_mask, key_padding_mask, - 
attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, - out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax, mask_filter_value); - } else { - const int blockSize = 256; - const int sh_bytes = sizeof(float) * total_sequence_length; - SoftmaxWithRawMaskLargeKernel - <<>>( - total_sequence_length, attention_mask, key_padding_mask, - attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, - out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax, mask_filter_value); - } + + DISPATCH_BIAS(attn_bias, HAS_BIAS, [&] { + if (total_sequence_length <= 32) { + const int blockSize = 32; + SoftmaxWithRawMaskSmallKernel + <<>>(total_sequence_length, sequence_length, attention_mask, key_padding_mask, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, + out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax, mask_filter_value); + } else if (total_sequence_length <= 64) { + const int blockSize = 64; + SoftmaxWithRawMaskSmallKernel + <<>>(total_sequence_length, sequence_length, attention_mask, key_padding_mask, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, + out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax, mask_filter_value); + } else if (total_sequence_length <= 128) { + const int blockSize = 128; + SoftmaxWithRawMaskSmallKernel + <<>>(total_sequence_length, sequence_length, attention_mask, key_padding_mask, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, + out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax, mask_filter_value); + } else if (total_sequence_length <= 256) { + const int blockSize = 256; + SoftmaxWithRawMaskSmallKernel + <<>>(total_sequence_length, sequence_length, attention_mask, key_padding_mask, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, 
input, + out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax, mask_filter_value); + } else if (total_sequence_length <= 512) { + const int blockSize = 512; + SoftmaxWithRawMaskSmallKernel + <<>>(total_sequence_length, sequence_length, attention_mask, key_padding_mask, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, + out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax, mask_filter_value); + } else if (total_sequence_length <= 1024) { + const int blockSize = 1024; + SoftmaxWithRawMaskSmallKernel + <<>>(total_sequence_length, sequence_length, attention_mask, key_padding_mask, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, + out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax, mask_filter_value); + } else { + const int blockSize = 256; + const int sh_bytes = sizeof(float) * total_sequence_length; + SoftmaxWithRawMaskLargeKernel + <<>>( + total_sequence_length, sequence_length, attention_mask, key_padding_mask, + attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, + out, causal, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax, mask_filter_value); + } + }); if (use_persistent_softmax) { return onnxruntime::cuda::dispatch_warpwise_softmax_forward( diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index c2a89232145c1..3f2ab31ceba7e 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -630,9 +630,8 @@ def get_cpu_kernel_name(config: MultiHeadAttentionConfig) -> str: # ------------------------------------------------------------------ # Functions for benchmarking PyTorch SDPA # ------------------------------------------------------------------ -def benchmark_torch_function(func: Callable, *args, **kwargs) 
-> float: +def benchmark_torch_function(repeats: int, func: Callable, *args, **kwargs) -> float: warmup = 5 - repeats = 100 for _ in range(warmup): func(*args, **kwargs) @@ -657,6 +656,7 @@ def run_torch_sdpa( mask_dim: int = 2, mask_dtype=torch.bool, backend: Optional[int] = None, + repeats: int = 100, ): q_shape = (batch_size, num_heads, q_seq_len, head_size) kv_shape = (batch_size, num_heads, kv_seq_len, head_size) @@ -673,6 +673,7 @@ def run_torch_sdpa( with context: average_latency = benchmark_torch_function( + repeats, scaled_dot_product_attention, q, k, @@ -683,7 +684,22 @@ def run_torch_sdpa( return average_latency -def get_test_configs(use_gpu: bool = True): +def get_test_configs(args: argparse.Namespace): + use_gpu: bool = args.use_gpu + + if args.batch_size > 0: + run_unfused = args.sequence_length + args.past_sequence_length <= (2048 if use_gpu else 1024) + return [ + ( + args.batch_size, + args.sequence_length, + args.past_sequence_length, + args.num_heads, + args.head_size, + run_unfused, + ), + ] + if use_gpu: # (batch_size, sequence_length, past_sequence_length, num_heads, head_size, run_unfused) configs = [ @@ -757,13 +773,15 @@ def get_compute_capability(): def run_tflops_test( csv_writer: csv.DictWriter, - use_gpu: bool = True, - enable_cuda_graph: bool = False, - causal: bool = False, - has_past: bool = False, - intra_op_num_threads: int = 0, - repeats: int = 100, + args: argparse.Namespace, ): + use_gpu: bool = args.use_gpu + enable_cuda_graph: bool = args.use_cuda_graph + causal: bool = args.causal + has_past: bool = args.has_past + intra_op_num_threads: int = args.intra_op_num_threads + repeats: int = args.repeats + print(f"run_tflops_test: causal={causal}") if use_gpu: @@ -774,9 +792,9 @@ def run_tflops_test( # flash attention is available for sm >= 80 sm = get_compute_capability() if sm >= 80: - backends = [SdpaKernel.DEFAULT, SdpaKernel.FLASH_ATTENTION, SdpaKernel.EFFICIENT_ATTENTION] + backends = [SdpaKernel.DEFAULT, 
SdpaKernel.FLASH_ATTENTION, SdpaKernel.EFFICIENT_ATTENTION, SdpaKernel.MATH] else: - backends = [SdpaKernel.DEFAULT, SdpaKernel.EFFICIENT_ATTENTION] + backends = [SdpaKernel.DEFAULT, SdpaKernel.EFFICIENT_ATTENTION, SdpaKernel.MATH] else: device_id = 0 device = torch.device("cpu") @@ -785,7 +803,7 @@ def run_tflops_test( provider = "CPUExecutionProvider" backends = [SdpaKernel.DEFAULT] - configs = get_test_configs(use_gpu) + configs = get_test_configs(args) print("\nformat\tcausal\tprompt\tbatch\tseqlen\theads\th_dim\tthreads\tms\tTFLOPS\tkernel") @@ -798,7 +816,7 @@ def run_tflops_test( num_heads=num_heads, head_size=head_size, causal=causal, - use_kv_cache=use_kv_cache, + use_kv_cache=use_kv_cache, # has present output? past_sequence_length=past_sequence_length, max_cache_sequence_length=None, kv_sequence_length=None, @@ -809,85 +827,87 @@ def run_tflops_test( share_past_present_buffer=False, input_format=input_format, ) - for attention_kernel in backends: - sess_options = SessionOptions() - sess_options.intra_op_num_threads = intra_op_num_threads - session = create_session(config, sess_options, attention_kernel=attention_kernel) - - if use_gpu: - kernel = get_gpu_kernel_name(attention_kernel) - else: - kernel = get_cpu_kernel_name(config) - - if "math" in kernel: - # Skip large sequence length for Unfused kernel to avoid OOM. - if not enable_unfused: - if config.verbose: - print(f"skip unfused kernel for {vars(config)}") - continue - - # Unfused kernel does not support packed QKV or packed KV formats. 
- if input_format not in [InputFormats.Q_K_V_BSNH_BSNH_BSNH]: - if config.verbose: - print(f"skip input_format for {vars(config)}") + for attention_kernel in backends: + sess_options = SessionOptions() + sess_options.intra_op_num_threads = intra_op_num_threads + session = create_session(config, sess_options, attention_kernel=attention_kernel) + + if use_gpu: + kernel = get_gpu_kernel_name(attention_kernel) + else: + kernel = get_cpu_kernel_name(config) + + if "math" in kernel: + # Skip large sequence length for Unfused kernel to avoid OOM. + if not enable_unfused: + if config.verbose: + print(f"skip unfused kernel for {vars(config)}") + continue + + # Unfused kernel does not support packed QKV or packed KV formats. + if input_format not in [InputFormats.Q_K_V_BSNH_BSNH_BSNH]: + if config.verbose: + print(f"skip input_format for {vars(config)}") + continue + + input_dict = config.random_inputs() + + # warm up session + try: + _ = measure_latency(session, input_dict) + except Exception as e: + print(f"Failed to run {kernel=} for {config=}. Exception: {e}") continue - input_dict = config.random_inputs() - - # warm up session - try: - _ = measure_latency(session, input_dict) - except Exception as e: - print(f"Failed to run {kernel=} for {config=}. 
Exception: {e}") - continue - - latency_list = [] - for _ in range(repeats): - latency = measure_latency(session, input_dict) - latency_list.append(latency) - average_latency = statistics.mean(latency_list) - - del session - - format_str = InputFormats.input_format_str(input_format) - - # compute TFLOPS per second - speed = None - if past_sequence_length == 0: - speed = tflops_per_second( - flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency + latency_list = [] + for _ in range(repeats): + latency = measure_latency(session, input_dict) + latency_list.append(latency) + average_latency = statistics.mean(latency_list) + + del session + + format_str = InputFormats.input_format_str(input_format) + + # compute TFLOPS per second + speed = None + if past_sequence_length == 0: + speed = tflops_per_second( + flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency + ) + + row = { + "use_gpu": use_gpu, + "enable_cuda_graph": enable_cuda_graph, + "format": format_str, + "causal": causal, + "batch_size": batch_size, + "sequence_length": sequence_length, + "past_sequence_length": past_sequence_length, + "num_heads": num_heads, + "head_size": head_size, + "intra_op_num_threads": intra_op_num_threads, + "average_latency": average_latency, + "tflops": speed, + "kernel": kernel, + } + csv_writer.writerow(row) + + speed = f"{speed:.2f}" if speed is not None else "NA" + print( + f"{format_str}\t{causal}\t{not has_past}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t" + f"{intra_op_num_threads}\t{average_latency * 1000:.2f}\t{speed}\t{kernel}" ) - row = { - "use_gpu": use_gpu, - "enable_cuda_graph": enable_cuda_graph, - "format": format_str, - "causal": causal, - "batch_size": batch_size, - "sequence_length": sequence_length, - "past_sequence_length": past_sequence_length, - "num_heads": num_heads, - "head_size": head_size, - "intra_op_num_threads": intra_op_num_threads, - "average_latency": average_latency, - 
"tflops": speed, - "kernel": kernel, - } - csv_writer.writerow(row) - - speed = f"{speed:.2f}" if speed is not None else "NA" - print( - f"{format_str}\t{causal}\t{not has_past}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t" - f"{intra_op_num_threads}\t{average_latency * 1000:.2f}\t{speed}\t{kernel}" - ) - def run_torch_test( csv_writer: csv.DictWriter, - use_gpu: bool = True, - causal: bool = False, + args: argparse.Namespace, ): - configs = get_test_configs(use_gpu) + use_gpu: bool = args.use_gpu + causal: bool = args.causal + + configs = get_test_configs(args) if use_gpu: if not torch.cuda.is_available(): @@ -939,6 +959,7 @@ def run_torch_test( device=device, dtype=dtype, backend=backend, + repeats=args.repeats, ) except RuntimeError: continue @@ -998,16 +1019,9 @@ def run_tflops_tests(args): csv_writer.writeheader() if args.torch: - run_torch_test(csv_writer, args.use_gpu, args.causal) + run_torch_test(csv_writer, args) else: - run_tflops_test( - csv_writer, - use_gpu=args.use_gpu, - enable_cuda_graph=args.use_cuda_graph, - causal=args.causal, - has_past=args.has_past, - intra_op_num_threads=args.intra_op_num_threads, - ) + run_tflops_test(csv_writer, args) def plot_prompt_performance( @@ -1151,6 +1165,60 @@ def _parse_arguments(): ) parser.set_defaults(causal=False) + parser.add_argument( + "-b", + "--batch_size", + required=False, + type=int, + default=0, + help="batch size", + ) + + parser.add_argument( + "-s", + "--sequence_length", + required=False, + type=int, + default=512, + help="sequence length", + ) + + parser.add_argument( + "-p", + "--past_sequence_length", + required=False, + type=int, + default=0, + help="past sequence length", + ) + + parser.add_argument( + "-n", + "--num_heads", + required=False, + type=int, + default=16, + help="number of attention heads", + ) + + parser.add_argument( + "-d", + "--head_size", + required=False, + type=int, + default=64, + help="hidden dimension per head", + ) + + parser.add_argument( + "-r", + 
"--repeats", + required=False, + type=int, + default=100, + help="number of repeats for performance test", + ) + parser.add_argument( "--torch", required=False, @@ -1181,7 +1249,7 @@ def _parse_arguments(): assert Version(torch.__version__) >= Version("2.3.0") assert args.has_past is False - if args.use_gpu and not args.torch: + if args.use_gpu and args.batch_size == 0 and not args.torch: if platform.system() == "Linux": s = torch.cuda.Stream() with torch.cuda.stream(s), torch.no_grad(): From a7e221b9c2a7854e12dac403fb4be8d7418df9f4 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 16 Aug 2024 07:22:36 +0000 Subject: [PATCH 11/13] benchmark mha with attention bias --- .../python/transformers/benchmark_mha.cmd | 8 + .../test/python/transformers/benchmark_mha.py | 232 ++++++++++-------- .../test/python/transformers/benchmark_mha.sh | 9 + 3 files changed, 145 insertions(+), 104 deletions(-) diff --git a/onnxruntime/test/python/transformers/benchmark_mha.cmd b/onnxruntime/test/python/transformers/benchmark_mha.cmd index 0a6d0c37b4a35..ba57ff40203b7 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.cmd +++ b/onnxruntime/test/python/transformers/benchmark_mha.cmd @@ -5,6 +5,14 @@ python benchmark_mha.py --use_gpu python benchmark_mha.py --use_gpu --use_cuda_graph python benchmark_mha.py --use_gpu --torch +echo "Benchmark performance on GPU without attention bias" +python benchmark_mha.py --use_gpu -b 16 + +echo "Benchmark performance on GPU with attention bias" +python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias +python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias --broadcast_attn_bias_dim_0 +python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias --broadcast_attn_bias_dim_0 --broadcast_attn_bias_dim_1 + type benchmark_mha_gpu_*.csv > mha_gpu_benchmark_results.csv echo "Benchmark performance on CPU with number of threads:" diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py 
b/onnxruntime/test/python/transformers/benchmark_mha.py index 3f2ab31ceba7e..50b94e7af285e 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -778,7 +778,6 @@ def run_tflops_test( use_gpu: bool = args.use_gpu enable_cuda_graph: bool = args.use_cuda_graph causal: bool = args.causal - has_past: bool = args.has_past intra_op_num_threads: int = args.intra_op_num_threads repeats: int = args.repeats @@ -804,101 +803,106 @@ def run_tflops_test( backends = [SdpaKernel.DEFAULT] configs = get_test_configs(args) - - print("\nformat\tcausal\tprompt\tbatch\tseqlen\theads\th_dim\tthreads\tms\tTFLOPS\tkernel") + print("\nformat\tcausal\tattBias\tbatch\tseqlen\tpast\theads\th_dim\tthreads\tms\tTFLOPS\tkernel") for input_format in formats: for batch_size, sequence_length, past_sequence_length, num_heads, head_size, enable_unfused in configs: - for use_kv_cache in [False]: - config = MultiHeadAttentionConfig( - batch_size=batch_size, - sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - use_kv_cache=use_kv_cache, # has present output? - past_sequence_length=past_sequence_length, - max_cache_sequence_length=None, - kv_sequence_length=None, - provider=provider, - enable_cuda_graph=enable_cuda_graph, - device=device, - dtype=torch.float16 if use_gpu else torch.float, - share_past_present_buffer=False, - input_format=input_format, - ) - for attention_kernel in backends: - sess_options = SessionOptions() - sess_options.intra_op_num_threads = intra_op_num_threads - session = create_session(config, sess_options, attention_kernel=attention_kernel) - - if use_gpu: - kernel = get_gpu_kernel_name(attention_kernel) - else: - kernel = get_cpu_kernel_name(config) - - if "math" in kernel: - # Skip large sequence length for Unfused kernel to avoid OOM. 
- if not enable_unfused: - if config.verbose: - print(f"skip unfused kernel for {vars(config)}") - continue - - # Unfused kernel does not support packed QKV or packed KV formats. - if input_format not in [InputFormats.Q_K_V_BSNH_BSNH_BSNH]: - if config.verbose: - print(f"skip input_format for {vars(config)}") - continue - - input_dict = config.random_inputs() - - # warm up session - try: - _ = measure_latency(session, input_dict) - except Exception as e: - print(f"Failed to run {kernel=} for {config=}. Exception: {e}") + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + use_kv_cache=past_sequence_length > 0, + past_sequence_length=past_sequence_length, + max_cache_sequence_length=None, + kv_sequence_length=None, + provider=provider, + enable_cuda_graph=enable_cuda_graph, + device=device, + dtype=torch.float16 if use_gpu else torch.float, + share_past_present_buffer=False, + input_format=input_format, + has_attn_bias=args.has_attn_bias, + broadcast_attn_bias_dim_0=args.broadcast_attn_bias_dim_0, + broadcast_attn_bias_dim_1=args.broadcast_attn_bias_dim_1, + ) + for attention_kernel in backends: + sess_options = SessionOptions() + sess_options.intra_op_num_threads = intra_op_num_threads + session = create_session(config, sess_options, attention_kernel=attention_kernel) + + if use_gpu: + kernel = get_gpu_kernel_name(attention_kernel) + else: + kernel = get_cpu_kernel_name(config) + + if "math" in kernel: + # Skip large sequence length for Unfused kernel to avoid OOM. + if not enable_unfused: + if config.verbose: + print(f"skip unfused kernel for {vars(config)}") + continue + + # Unfused kernel does not support packed QKV or packed KV formats. 
+ if input_format not in [InputFormats.Q_K_V_BSNH_BSNH_BSNH]: + if config.verbose: + print(f"skip input_format for {vars(config)}") continue - latency_list = [] - for _ in range(repeats): - latency = measure_latency(session, input_dict) - latency_list.append(latency) - average_latency = statistics.mean(latency_list) - - del session - - format_str = InputFormats.input_format_str(input_format) - - # compute TFLOPS per second - speed = None - if past_sequence_length == 0: - speed = tflops_per_second( - flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency - ) - - row = { - "use_gpu": use_gpu, - "enable_cuda_graph": enable_cuda_graph, - "format": format_str, - "causal": causal, - "batch_size": batch_size, - "sequence_length": sequence_length, - "past_sequence_length": past_sequence_length, - "num_heads": num_heads, - "head_size": head_size, - "intra_op_num_threads": intra_op_num_threads, - "average_latency": average_latency, - "tflops": speed, - "kernel": kernel, - } - csv_writer.writerow(row) - - speed = f"{speed:.2f}" if speed is not None else "NA" - print( - f"{format_str}\t{causal}\t{not has_past}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t" - f"{intra_op_num_threads}\t{average_latency * 1000:.2f}\t{speed}\t{kernel}" + input_dict = config.random_inputs() + + # warm up session + try: + _ = measure_latency(session, input_dict) + except Exception as e: + print(f"Failed to run {kernel=} for {config=}. 
Exception: {e}") + continue + + latency_list = [] + for _ in range(repeats): + latency = measure_latency(session, input_dict) + latency_list.append(latency) + average_latency = statistics.mean(latency_list) + + del session + + format_str = InputFormats.input_format_str(input_format) + + # compute TFLOPS per second + speed = None + if past_sequence_length == 0: + speed = tflops_per_second( + flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency ) + row = { + "use_gpu": use_gpu, + "enable_cuda_graph": enable_cuda_graph, + "format": format_str, + "causal": causal, + "batch_size": batch_size, + "sequence_length": sequence_length, + "past_sequence_length": past_sequence_length, + "num_heads": num_heads, + "head_size": head_size, + "has_attn_bias": args.has_attn_bias, + "broadcast_attn_bias_dim_0": args.broadcast_attn_bias_dim_0, + "broadcast_attn_bias_dim_1": args.broadcast_attn_bias_dim_1, + "intra_op_num_threads": intra_op_num_threads, + "average_latency": average_latency, + "tflops": speed, + "kernel": kernel, + } + csv_writer.writerow(row) + + speed = f"{speed:.2f}" if speed is not None else "NA" + print( + f"{format_str}\t{causal}\t{args.has_attn_bias}\t{batch_size}\t" + f"{sequence_length}\t{past_sequence_length}\t{num_heads}\t{head_size}\t" + f"{intra_op_num_threads}\t{average_latency * 1000:.2f}\t{speed}\t{kernel}" + ) + def run_torch_test( csv_writer: csv.DictWriter, @@ -967,8 +971,9 @@ def run_torch_test( speed = tflops_per_second(flops(batch_size, sequence_length, head_size, num_heads, causal), torch_latency) input_format = "Q,K,V" print( - f"{input_format}\t{causal}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t" - f"{0}\t{torch_latency * 1000:.2f}\t{speed:.2f}\t{backend_name}" + f"{input_format}\t{causal}\t{False}\t{batch_size}\t" + f"{sequence_length}\t{past_sequence_length}\t{num_heads}\t{head_size}\t" + f"{torch.get_num_threads()}\t{torch_latency * 1000:.2f}\t{speed}\t{backend_name}" ) row = { "use_gpu": use_gpu, 
@@ -980,6 +985,9 @@ def run_torch_test( "past_sequence_length": past_sequence_length, "num_heads": num_heads, "head_size": head_size, + "has_attn_bias": False, + "broadcast_attn_bias_dim_0": False, + "broadcast_attn_bias_dim_1": False, "intra_op_num_threads": torch.get_num_threads(), "average_latency": torch_latency, "tflops": speed, @@ -992,7 +1000,7 @@ def run_tflops_tests(args): features = "gpu" if args.use_gpu else "cpu" if args.causal: features += "_causal" - if args.has_past: + if args.past_sequence_length > 0: features += "_past" csv_filename = "benchmark_mha_{}_{}_{}.csv".format( features, @@ -1010,6 +1018,9 @@ def run_tflops_tests(args): "past_sequence_length", "num_heads", "head_size", + "has_attn_bias", + "broadcast_attn_bias_dim_0", + "broadcast_attn_bias_dim_1", "intra_op_num_threads", "average_latency", "tflops", @@ -1149,14 +1160,6 @@ def _parse_arguments(): help="intra_op_num_threads for onnxruntime. ", ) - parser.add_argument( - "--has_past", - required=False, - action="store_true", - help="whether past_sequence_length > 0", - ) - parser.set_defaults(has_past=False) - parser.add_argument( "--causal", required=False, @@ -1227,6 +1230,30 @@ def _parse_arguments(): ) parser.set_defaults(torch=False) + parser.add_argument( + "--has_attn_bias", + required=False, + action="store_true", + help="has attention bias", + ) + parser.set_defaults(has_attn_bias=False) + + parser.add_argument( + "--broadcast_attn_bias_dim_0", + required=False, + action="store_true", + help="broadcast attention bias dimension 0", + ) + parser.set_defaults(broadcast_attn_bias_dim_0=False) + + parser.add_argument( + "--broadcast_attn_bias_dim_1", + required=False, + action="store_true", + help="broadcast attention bias dimension 1", + ) + parser.set_defaults(broadcast_attn_bias_dim_1=False) + args = parser.parse_args() return args @@ -1236,9 +1263,6 @@ def _parse_arguments(): args = _parse_arguments() print(f"arguments:{args}") - if args.has_past: - assert args.causal, "--has_past 
need --causal specified" - if args.use_gpu: assert args.torch or not args.causal, "no causal cuda kernel in MHA op" assert torch.cuda.is_available() @@ -1247,7 +1271,7 @@ def _parse_arguments(): if args.torch: assert Version(torch.__version__) >= Version("2.3.0") - assert args.has_past is False + assert args.past_sequence_length == 0 if args.use_gpu and args.batch_size == 0 and not args.torch: if platform.system() == "Linux": diff --git a/onnxruntime/test/python/transformers/benchmark_mha.sh b/onnxruntime/test/python/transformers/benchmark_mha.sh index 613543d0172dd..ff6dd16e698df 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.sh +++ b/onnxruntime/test/python/transformers/benchmark_mha.sh @@ -9,6 +9,15 @@ echo "Benchmark Scaled Dot Product Attention (SDPA) performance on GPU:" export CUDA_VISIBLE_DEVICES=0 python benchmark_mha.py --use_gpu + +echo "Benchmark BERT-Large performance on GPU without attention bias" +python benchmark_mha.py --use_gpu -b 16 + +echo "Benchmark BERT-Large performance on GPU with attention bias" +python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias +python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias --broadcast_attn_bias_dim_0 +python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias --broadcast_attn_bias_dim_0 --broadcast_attn_bias_dim_1 + python benchmark_mha.py --use_gpu --use_cuda_graph python benchmark_mha.py --use_gpu --torch From 1226c6d04268743f60cb24368c447f5ae54af126 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 16 Aug 2024 08:17:52 +0000 Subject: [PATCH 12/13] mark maybe_unused --- .../cuda/bert/attention_softmax.cu | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu index f4647a514e7e8..52f94247a8b2b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu @@ 
-29,16 +29,16 @@ namespace onnxruntime { namespace contrib { namespace attention_softmax_cuda { -#define DISPATCH_BIAS(attn_bias, HAS_BIAS, ...) \ - [&] { \ - const dim3 grid(num_heads * sequence_length, batch_size, 1); \ - if (attn_bias != nullptr) { \ - constexpr static bool HAS_BIAS = true; \ - return __VA_ARGS__(); \ - } else { \ - constexpr static bool HAS_BIAS = false; \ - return __VA_ARGS__(); \ - } \ +#define DISPATCH_BIAS(attn_bias, HAS_BIAS, ...) \ + [&] { \ + const dim3 grid(num_heads* sequence_length, batch_size, 1); \ + if (attn_bias != nullptr) { \ + constexpr static bool HAS_BIAS = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool HAS_BIAS = false; \ + return __VA_ARGS__(); \ + } \ }() // Macro to declare variables: @@ -50,10 +50,10 @@ namespace attention_softmax_cuda { // input and output shape is (batch_size, num_heads, sequence_length, total_sequence_length) // bias shape is (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length) #define DECLARE_SOFTMAX_VARS() \ - const int s = blockIdx.x % sequence_length; \ + [[maybe_unused]] const int s = blockIdx.x % sequence_length; \ const int b = blockIdx.y; \ int64_t offset = static_cast(b * gridDim.x + blockIdx.x) * static_cast(total_sequence_length); \ - int64_t bias_offset = 0; \ + [[maybe_unused]] int64_t bias_offset = 0; \ if constexpr (HAS_BIAS) { \ const int j = (broadcast_attn_bias_dim_0 ? 0 : (b * gridDim.x)) + (broadcast_attn_bias_dim_1 ? 
s : blockIdx.x); \ bias_offset = static_cast(j) * static_cast(total_sequence_length); \ From 0c7f3952d061f8db812fa3b9ccf30cf4f69e1fb3 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 16 Aug 2024 15:51:12 +0000 Subject: [PATCH 13/13] refine attn_bias_offset for dmmha with assumption of S=1 --- .../decoder_masked_multihead_attention_impl.cu | 17 +++++++---------- .../cuda/bert/packed_multihead_attention.cc | 3 +-- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index 235b37368ea6b..8edae863ff44e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -155,16 +155,13 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE; // The offset of attention bias for current head. - int64_t attn_bias_offset = 0; - if (params.attention_bias != nullptr) { - // Support broadcasting the first and second dimensions of attention bias. - if (!params.broadcast_attn_bias_dim_0) { - attn_bias_offset = static_cast(bbi) * params.num_heads * params.sequence_length * params.total_sequence_length; - } - if (!params.broadcast_attn_bias_dim_1) { - attn_bias_offset += static_cast(hi) * params.sequence_length * params.total_sequence_length; - } - } + // Support broadcasting the first and second dimensions of attention bias with shape + // [batch_size or 1, num_heads or 1, seq_len, total_seq_len], and assume seq_len == 1 for this operator. + int attn_bias_offset = (params.attention_bias == nullptr) + ? 0 + : (((params.broadcast_attn_bias_dim_0 ? 
0 : (bbi * params.num_heads)) + + (params.broadcast_attn_bias_dim_1 ? 0 : hi)) * + params.total_sequence_length); // Trigger the loads from the Q and K buffers. Qk_vec_k q; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc index 35f43aa9fdc7b..72a4c776d4fce 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc @@ -226,9 +226,8 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co #if USE_MEMORY_EFFICIENT_ATTENTION if (!use_flash_attention && nullptr == fused_runner && !disable_memory_efficient_attention_) { int sm = device_prop.major * 10 + device_prop.minor; - bool is_attn_bias_aligned = nullptr == attention_bias || parameters.sequence_length % (4 * sizeof(T)) == 0; use_memory_efficient_attention = - is_attn_bias_aligned && + (nullptr == attention_bias || parameters.sequence_length % (4 * sizeof(T)) == 0) && (sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) && has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size); }