From 8b492e80a8dba314f0be9425cd6c154355b9b7dd Mon Sep 17 00:00:00 2001
From: Arthur Islamov
Date: Mon, 11 Sep 2023 23:37:38 +0400
Subject: [PATCH] Everything works

---
 js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts | 59 ++++++++++---------
 .../lib/wasm/jsep/webgpu/ops/instance-norm.ts | 31 +++++-----
 .../jsep/webgpu/ops/multi-head-attentiion.ts  | 14 ++---
 .../lib/wasm/jsep/webgpu/program-manager.ts   |  1 -
 onnxruntime/contrib_ops/cpu/bert/attention.cc | 17 ------
 5 files changed, 52 insertions(+), 70 deletions(-)

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts
index 8db353470e4c7..b4246886048ee 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts
@@ -14,25 +14,25 @@ import {
   ShaderHelper,
   sumVector,
   tensorTypeToWsglStorageType
-} from './common'
-import {createTransposeProgramInfo, TransposeAttributes, transposeProgramMetadata} from './transpose'
+} from './common';
+import {createTransposeProgramInfo, TransposeAttributes, transposeProgramMetadata} from './transpose';
 
 export enum AttentionQkvFormat {
-  UNKNOWN,               // enum value not set, or depends on qkv projection implementation details
-  Q_K_V_BNSH,            // for non-packed qkv, permuted
-  Q_K_V_BSNH,            // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention
-  QKV_BSN3H,             // for TRT fused attention, qkv are packed
-  Q_K_V_BNSH_QKV_BS3NH,  // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH)
-  Q_KV_BSNH_BSN2H,       // for TRT fused cross attention, kv are packed
-  Q_K_V_TNH,             // for memory efficient attention, qkv are not packed, and paddings are removed.
-  QKV_TN3H,              // for TRT fused attention, qkv are packed and paddings are removed
+  unknown,          // enum value not set, or depends on qkv projection implementation details
+  qkvBNSH,          // for non-packed qkv, permuted
+  qkvBSNH,          // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention
+  qkvBSN3H,         // for TRT fused attention, qkv are packed
+  qkvBNSHqkvBS3NH,  // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH)
+  qKvBSNHxBSN2H,    // for TRT fused cross attention, kv are packed
+  qkvTNH,           // for memory efficient attention, qkv are not packed, and paddings are removed.
+  qkvTN3H,          // for TRT fused attention, qkv are packed and paddings are removed
 }
 
 export enum AttentionMaskType {
-  MASK_NONE,                  // No mask
-  MASK_1D_KEY_SEQ_LEN,        // [batch_size], key sequence length
-  MASK_1D_END_START,          // [2 * batch_size] with end positions and start positions
-  MASK_1D_KEY_SEQ_LEN_START,  // [3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1], query_start[0],
+  none,                  // No mask
+  mask1dKeySeqLen,       // [batch_size], key sequence length
+  mask1dEndStart,        // [2 * batch_size] with end positions and start positions
+  mask1DKeySeqLenStart,  // [3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1], query_start[0],
                               // ..., query_start[batch_size - 1], query_end[batch_size - 1], key_start[0], ...,
                               // key_start[batch_size - 1], key_end[batch_size - 1]]
   MASK_2D_DUMMY,              // dummy mask with shape [1, 1] or [batch_size, 1]. It has same effect as no mask.
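
Note on the hunk above: the enum members move from SCREAMING_SNAKE_CASE to lowerCamelCase, so every call site in the WebGPU ops has to be updated in the same patch (the later hunks do exactly that). A minimal sketch of how a consumer changes; the import path and variable names here are illustrative only, not part of this diff:

    import {AttentionMaskType, AttentionQkvFormat} from './attentiion';

    // Before this patch the same defaults were spelled
    // AttentionQkvFormat.Q_K_V_BNSH and AttentionMaskType.MASK_NONE.
    const qkvFormat = AttentionQkvFormat.qkvBNSH;
    const maskType = AttentionMaskType.none;
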
@@ -170,7 +170,7 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte
   const totalSequenceLength = kvSequenceLength + pastSequenceLength;
   const maxSequenceLength = -1;
 
-  let maskType = AttentionMaskType.MASK_NONE;
+  let maskType = AttentionMaskType.none;
   if (maskIndex) {
     // maskType = AttentionMaskType.MASK_UNKNOWN;
     // TODO: handle mask
@@ -204,7 +204,7 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte
     scale: attributes.scale,
     broadcastResPosBias: false,
     passPastInKv: false,
-    qkvFormat: AttentionQkvFormat.Q_K_V_BNSH,
+    qkvFormat: AttentionQkvFormat.qkvBNSH,
   };
 };
 
@@ -422,9 +422,8 @@ const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v:
   const M: u32 = ${params.sequenceLength}u;
   const N: u32 = ${params.vHeadSize}u;
   const K: u32 = ${params.totalSequenceLength / components}u;
-  const numHeads: u32 = ${params.numHeads}u;
   const TILE_SIZE = ${TILE_SIZE}u;
-  
+
   var tileQ: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
   var tileK: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
 
@@ -437,30 +436,32 @@ const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v:
       workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index;
     let headIdx = workgroup_id.z;
-    let m = workgroup_id.y * TILE_SIZE + local_id.y;
-    let n = workgroup_id.x * TILE_SIZE + local_id.x;
+    let m = workgroup_id.y * TILE_SIZE;
+    let n = workgroup_id.x * TILE_SIZE;
+    let lm = m + local_id.y;
+    let ln = n + local_id.x;
     let offsetA = headIdx * (M * K) + m * K;
     let offsetB = headIdx * (N * K) + n;
 
     var value = ${fillVector(dataType, components)};
     for (var w: u32 = 0u; w < K; w += TILE_SIZE) {
-      if (m < M && w + local_id.x < K) {
-        tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];
+      if (m + local_id.y < M && w + local_id.x < K) {
+        tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + local_id.y * K + w + local_id.x];
       }
-      if (n < N && w + local_id.y < K) {
-        tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * N];
+      if (n + local_id.y < N && w + local_id.x < K) {
+        tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + local_id.y * K + w + local_id.x];
       }
       workgroupBarrier();
 
      for (var k: u32 = 0u; k 65504) {
-    divisor = `f16(${h / 2}) / 2.0h`;
-  }
-
   const getMeanShaderSource = (shaderHelper: ShaderHelper) => `
   const H: u32 = ${h};
   const C: u32 = ${c / components};
@@ -131,8 +126,8 @@ const computeMean = (context: ComputeContext, input: TensorView, scale: TensorVi
       squaredSum += value * value;
     }
     // we need to divide it here to avoid fp16 overflow
-    sum = sum / ${divisor};
-    squaredSum = squaredSum / ${divisor};
+    sum = sum / ${wgSize};
+    squaredSum = squaredSum / ${wgSize};
     output[global_idx] = ${setOutputValue('sum', 'squaredSum')};
   }`;
 
@@ -172,6 +167,8 @@ const computeMean = (context: ComputeContext, input: TensorView, scale: TensorVi
       sum += value[0];
       squaredSum += value[1];
     }
+    sum = sum / ${h / wgSize};
+    squaredSum = squaredSum / ${h / wgSize};
     let invStdDev = 1 / sqrt(squaredSum - sum * sum + epsilon);
     let channelScale = invStdDev * scale[currentChannelNumber];
     let channelShift = bias[currentChannelNumber] - sum * channelScale;
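
The instance-norm hunks above avoid fp16 overflow by splitting the scaling across the two reduction passes: the first pass divides each workgroup's partial sum and squared sum by the workgroup size, and the second pass divides the accumulated partials by h / wgSize, so the combined result is still the ordinary mean. A plain-TypeScript sketch of the same arithmetic (the wgSize value and the assumption that h divides evenly by it are mine, not taken from the patch):

    const wgSize = 64;  // assumed workgroup size for the sketch

    // Returns [mean(x), mean(x * x)] without ever holding a full-size sum,
    // mirroring the two WGSL passes in computeMean above.
    function meanAndSquaredMean(x: number[]): [number, number] {
      const partials: Array<[number, number]> = [];
      for (let i = 0; i < x.length; i += wgSize) {
        let sum = 0;
        let squaredSum = 0;
        for (let j = i; j < i + wgSize; j++) {
          sum += x[j];
          squaredSum += x[j] * x[j];
        }
        // First-pass divide: keeps partial sums below the fp16 limit (65504).
        partials.push([sum / wgSize, squaredSum / wgSize]);
      }
      let sum = 0;
      let squaredSum = 0;
      for (const [s, sq] of partials) {
        sum += s;
        squaredSum += sq;
      }
      // Second-pass divide by h / wgSize restores the true means.
      const groups = x.length / wgSize;
      return [sum / groups, squaredSum / groups];
    }

From these two values the shader then forms the variance as squaredSum - sum * sum (E[x²] − E[x]²) before taking the inverse standard deviation, as in the @@ -172 hunk above.
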
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts
index 888647fc4c731..e362f587dd2d0 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts
@@ -102,7 +102,7 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr
       if (key.dims[2] !== query.dims[2]) {
         throw new Error('Input \'query\' and \'key\' shall have same dim 2 (hidden_size)');
       }
-      qkvFormat = AttentionQkvFormat.Q_K_V_BSNH;
+      qkvFormat = AttentionQkvFormat.qkvBSNH;
       kvSequenceLength = key.dims[1];
     } else if (key.dims.length === 5) {
       if (key.dims[2] !== attributes.numHeads || key.dims[3] !== 2 || key.dims[4] !== headSize) {
@@ -111,14 +111,14 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr
       if (value) {
         throw new Error('Expect \'value\' be none when \'key\' has packed kv format.');
       }
-      qkvFormat = AttentionQkvFormat.Q_KV_BSNH_BSN2H;
+      qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H;
       kvSequenceLength = key.dims[1];
     } else {  // key_dims.size() == 4 (cross-attention with past_key)
       if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize) {
         throw new Error('Expect \'key\' shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key');
       }
 
-      qkvFormat = AttentionQkvFormat.UNKNOWN;
+      qkvFormat = AttentionQkvFormat.unknown;
       kvSequenceLength = key.dims[2];
     }
   } else {  // packed QKV
@@ -129,7 +129,7 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr
       throw new Error(
          'Expect \'query\' shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv');
     }
-    qkvFormat = AttentionQkvFormat.QKV_BSN3H;
+    qkvFormat = AttentionQkvFormat.qkvBSN3H;
   }
 
   if (bias) {
@@ -144,15 +144,15 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr
     }
   }
 
-  let maskType: AttentionMaskType = AttentionMaskType.MASK_NONE;
+  let maskType: AttentionMaskType = AttentionMaskType.none;
   if (keyPaddingMask) {
     maskType = AttentionMaskType.MASK_UNKNOWN;
     const maskDims = keyPaddingMask.dims;
     if (maskDims.length === 1) {
       if (maskDims[0] === batchSize) {
-        maskType = AttentionMaskType.MASK_1D_KEY_SEQ_LEN;
+        maskType = AttentionMaskType.mask1dKeySeqLen;
       } else if (maskDims[0] === 3 * batchSize + 2) {
-        maskType = AttentionMaskType.MASK_1D_KEY_SEQ_LEN_START
+        maskType = AttentionMaskType.mask1DKeySeqLenStart
       }
     } else if (maskDims.length === 2 && maskDims[0] === batchSize && maskDims[1] === kvSequenceLength) {
       maskType = AttentionMaskType.MASK_2D_KEY_PADDING;
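
To summarize the keyPaddingMask handling that validateInputs performs in the multi-head attention hunks above: the mask tensor's shape selects the mask type, and anything unrecognized stays MASK_UNKNOWN. The helper below is a hypothetical standalone rewrite for illustration only; it is not part of the patch:

    // Hypothetical helper (not in the diff) mirroring the shape dispatch above.
    function maskTypeFromDims(
        maskDims: readonly number[], batchSize: number, kvSequenceLength: number): AttentionMaskType {
      if (maskDims.length === 1) {
        if (maskDims[0] === batchSize) {
          return AttentionMaskType.mask1dKeySeqLen;       // [batch_size]
        }
        if (maskDims[0] === 3 * batchSize + 2) {
          return AttentionMaskType.mask1DKeySeqLenStart;  // [3 * batch_size + 2]
        }
      } else if (maskDims.length === 2 && maskDims[0] === batchSize && maskDims[1] === kvSequenceLength) {
        return AttentionMaskType.MASK_2D_KEY_PADDING;     // [batch_size, kv_sequence_length]
      }
      return AttentionMaskType.MASK_UNKNOWN;
    }
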
diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts
index d4948108a023c..a230a4bace3ee 100644
--- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts
+++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts
@@ -67,7 +67,6 @@ export class ProgramManager {
     //   usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
     // });
     //
-    //
     // const commandEncoder = this.backend.getCommandEncoder();
     // commandEncoder?.copyBufferToBuffer(
     //     output.buffer,
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc
index 3d2f7a29d41f8..4711ccf487cc8 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -327,23 +327,6 @@ Status Attention::Compute(OpKernelContext* context) const {
     });
   }
 
-  std::cout << "Prepare completed.";
-  std::cout << "First 10 values at Q: ";
-  for (size_t i = 0; i < qkv_head_size[0] * sequence_length * batch_size * num_heads_; ++i) {
-    std::cout << Q[i] << " ";
-  }
-  std::cout << std::endl;
-  std::cout << "First 10 values at K: ";
-  for (size_t i = 0; i < qkv_head_size[1] * sequence_length * batch_size * num_heads_; ++i) {
-    std::cout << K[i] << " ";
-  }
-  std::cout << std::endl;
-  std::cout << "First 10 values at V: ";
-  for (size_t i = 0; i < qkv_head_size[2] * sequence_length * batch_size * num_heads_; ++i) {
-    std::cout << V[i] << " ";
-  }
-  std::cout << std::endl;
-
   // Compute the attention score and apply the score to V
   return ApplyAttention(Q, K, V, mask_index, past, nullptr /* past_key */, nullptr /* past_value */,
                         output, nullptr /* present_key */, nullptr /* present_value */,