diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
index 222a1ce91c95a..744f6d3a04bc4 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
@@ -15,7 +15,7 @@ import {
   sumVector,
   tensorTypeToWsglStorageType
 } from './common';
-import {createTransposeProgramInfo, TransposeAttributes, transposeProgramMetadata} from './transpose';
+import {transposeProgramMetadata} from './transpose';
 
 export enum AttentionQkvFormat {
   unknown,  // enum value not set, or depends on qkv projection implementation details
@@ -211,74 +211,44 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs
 export const parseAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs =>
     createAttributeWithCacheKey({...attributes});
 
-const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [0, 2, 1, 3]});
-
 export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView, N: number, D: number) => {
-  const components = getMaxComponents(D);
-  const inputHelper = outputVariable('x', input.dataType, input.dims, components);
-
-  let threadMaxValue = 'threadMaxVector';
-  if (components === 2) {
-    threadMaxValue = 'max(threadMaxVector.x, threadMaxVector.y)';
-  } else if (components === 4) {
-    threadMaxValue = 'max(max(threadMaxVector.x, threadMaxVector.y), max(threadMaxVector.z, threadMaxVector.w))';
-  }
   const dataType = tensorTypeToWsglStorageType(input.dataType);
-  let WG = 64;
-  const dComp = D / components;
-  if (dComp < WG) {
-    WG = 1;
-  } else if (dComp / 8 < 64) {
-    WG = Math.ceil(dComp / 8);
-  }
-  const elementsPerWG = Math.ceil(D / components / WG);
-  // 6.2.4 in wgsl spec
-  const threadMaxMinValue = dataType === 'f32' ? '-3.402823e+38f' : '-65504.0h';
 
   const getShaderSource = (shaderHelper: ShaderHelper) => `
   const dInv: ${dataType} = ${dataType}(1) / ${dataType}(${D});
-  const dComp = ${D / components};
-  var<workgroup> wgMax: array<${dataType}, ${WG}>;
-  var<workgroup> wgSum: array<${dataType}, ${WG}>;
-
-  ${shaderHelper.declareVariables(inputHelper)}
-  @compute @workgroup_size(${WG}, 1, 1)
-  fn main(@builtin(workgroup_id) workgroup_id : vec3<u32>,
-          @builtin(local_invocation_index) local_index : u32) {
-    let localOffset = local_index * ${elementsPerWG};
-    let offset: u32 = workgroup_id.x * dComp + localOffset;
-
-    var threadMaxVector = ${fillVector(dataType, components, threadMaxMinValue)};
-    for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
-      threadMaxVector = max(x[offset + i], threadMaxVector);
-    }
-    wgMax[local_index] = ${threadMaxValue};
-    workgroupBarrier();
+  @group(0) @binding(0) var<storage, read_write> x: array<${dataType}>;
+  @group(0) @binding(1) var<storage, read_write> x2: array<${dataType}>;
 
-    var maxValue = ${threadMaxMinValue};
-    for (var i = 0u; i < ${WG}; i++) {
-      maxValue = max(wgMax[i], maxValue);
+  ${shaderHelper.mainStart()}
+    if (global_idx >= ${N}) {
+      return;
     }
+    let offset: u32 = global_idx * ${D};
 
-    var sumVector = ${fillVector(dataType, components, '0')};
-    for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
-      sumVector += exp(x[offset + i] - maxValue);
+    var threadMax = -3.402823e+38f;  // 6.2.4 in wgsl spec
+    for (var i: u32 = 0; i < ${D}; i++) {
+      threadMax = max(f32(x[offset + i]), threadMax);
     }
-    wgSum[local_index] = ${sumVector('sumVector', components)};
-    workgroupBarrier();
 
-    var sum: ${dataType} = 0;
-    for (var i = 0u; i < ${WG}; i++) {
-      sum += wgSum[i];
+    var sum: f32 = 0.0;
+    for (var i: u32 = 0; i < ${D}; i++) {
+      let val: f32 = exp(f32(x[offset + i]) - threadMax);
+      // x[offset + i] = ${dataType}(val);
+      sum += val;
     }
+    // for (var i: u32 = 0; i < ${D}; i++) {
+    //   sum += x[offset + i];
+    // }
 
    if (sum == 0) {
-      for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
-        x[offset + i] = ${fillVector(dataType, components, 'dInv')};
+      for (var i: u32 = 0; i < ${D}; i++) {
+        x[offset + i] = dInv;
+        x2[offset + i] = dInv;
      }
    } else {
-      for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
-        x[offset + i] = exp(x[offset + i] - maxValue) / sum;
+      for (var i: u32 = 0; i < ${D}; i++) {
+        x[offset + i] = ${dataType}(exp(f32(x[offset + i]) - threadMax) / sum);
+        x2[offset + i] = x[offset + i];
      }
    }
  }`;
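The rewritten kernel trades the vectorized, workgroup-reduced softmax for a simpler one-thread-per-row form: each invocation scans its row of length D for the maximum, accumulates the sum of exp(x - max), and normalizes, falling back to a uniform 1/D when a fully masked row underflows to a zero sum. A minimal host-side reference of the same computation, useful for cross-checking the shader output; `softmaxRowsInPlace` is a hypothetical helper, not part of this patch:

```ts
// Numerically stable softmax over N rows of length D (row-major), mirroring
// the WGSL above, including the uniform 1/D fallback for all-masked rows.
const softmaxRowsInPlace = (data: Float32Array, N: number, D: number): void => {
  for (let n = 0; n < N; n++) {
    const offset = n * D;
    let rowMax = -Infinity;
    for (let i = 0; i < D; i++) {
      rowMax = Math.max(rowMax, data[offset + i]);
    }
    let sum = 0;
    for (let i = 0; i < D; i++) {
      sum += Math.exp(data[offset + i] - rowMax);
    }
    for (let i = 0; i < D; i++) {
      // A row of all -inf (fully masked) gives sum == 0; use uniform weights.
      data[offset + i] = sum === 0 ? 1 / D : Math.exp(data[offset + i] - rowMax) / sum;
    }
  }
};
```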
@@ -288,11 +258,13 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
         name: 'computeAttentionProbsSoftmax',
         cacheHint: '0',
         inputTypes: [GpuDataType.default],
-        outputs: [],
+        outputs: [
+          {dims: input.dims, dataType: DataType.float, gpuDataType: GpuDataType.default}
+        ],
         getShaderSource,
-        dispatchGroup: () => ({x: N})
+        dispatchGroup: () => ({x: Math.ceil(N / 64)})
       },
-      {inputs: [input], outputs: []});
+      {inputs: [input], outputs: [-1]});
 };
 
 const computeAttentionProbs =
@@ -409,11 +381,11 @@ const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v:
       const dataType = tensorTypeToWsglStorageType(probs.dataType);
 
-      const TILE_SIZE = 8;
+      const TILE_SIZE = 1;
       const dispatch = {
         x: Math.ceil(params.vHeadSize / TILE_SIZE),
-        y: Math.ceil(params.sequenceLength / TILE_SIZE),
-        z: params.batchSize * params.numHeads
+        y: Math.ceil(params.totalSequenceLength / TILE_SIZE),
+        z: params.batchSize * params.numHeads,
       };
 
       const getShaderSource = (shaderHelper: ShaderHelper) => `
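With one thread per softmax row and the 64-wide default workgroup from `mainStart()`, the grid is `ceil(N / 64)` groups, and with `TILE_SIZE = 1` the score kernel covers one output element per workgroup across head size, total sequence length, and batch × heads. A sketch of that dispatch arithmetic, with hypothetical helper names and assuming the 64-wide default:

```ts
// Grid sizes for the two programs above (illustrative only).
const softmaxDispatch = (N: number, workgroupSize = 64) => ({
  x: Math.ceil(N / workgroupSize),  // one invocation per softmax row
});

const vxScoreDispatch = (
    vHeadSize: number, totalSequenceLength: number, batchSize: number, numHeads: number, tileSize = 1) => ({
  x: Math.ceil(vHeadSize / tileSize),
  y: Math.ceil(totalSequenceLength / tileSize),
  z: batchSize * numHeads,  // one z-slice per (batch, head) pair
});
```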
@@ -483,13 +455,35 @@ export const applyAttention =
       const attentionResult = computeVxAttentionScore(context, probs, v, parameters);
 
+      const outputShape = [parameters.batchSize, parameters.sequenceLength, parameters.vHiddenSize];
+      const input = inputVariable('input', q.dataType, attentionResult.dims);
+      const output = outputVariable('output', q.dataType, outputShape);
+      const getShaderSource = (shaderHelper: ShaderHelper) => `
+  ${shaderHelper.declareVariables(input, output)}
+
+  ${shaderHelper.mainStart(parameters.numHeads * parameters.batchSize)}
+    let batchIndex = global_idx / ${parameters.numHeads};
+    let headIndex = global_idx % ${parameters.numHeads};
+    // let in = input[0];
+
+    var inputOffset = ${parameters.sequenceLength * parameters.vHeadSize} * global_idx;
+    var outputOffset = (batchIndex * ${parameters.sequenceLength * parameters.numHeads} + headIndex) * ${parameters.vHeadSize};
+    for (var j = 0; j < ${parameters.sequenceLength}; j++) {
+      for (var i: u32 = 0; i < ${parameters.vHeadSize}; i++) {
+        output[outputOffset + i] = input[inputOffset + i];
+      }
+      inputOffset += ${parameters.vHeadSize};
+      outputOffset += ${parameters.vHiddenSize};
+    }
+  }`;
+
       context.compute(
           {
             ...transposeProgramMetadata,
-            cacheHint: JSON.stringify(parameters) + JSON.stringify(attributes),
-            get: () => createTransposeProgramInfo(
-                attentionResult, weightTransposeAttribute.perm,
-                [parameters.batchSize, parameters.sequenceLength, parameters.vHiddenSize])
+            cacheHint: JSON.stringify(parameters),
+            outputs: [{dims: outputShape, dataType: DataType.float, gpuDataType: GpuDataType.default}],
+            getShaderSource,
+            dispatchGroup: () => ({x: 1}),
           },
          {inputs: [attentionResult], outputs: [0]});
    };
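The inline shader above replaces the generic `createTransposeProgramInfo` program with a hand-rolled re-layout of the attention result from (batch, numHeads, sequenceLength, vHeadSize) to (batch, sequenceLength, numHeads × vHeadSize), one (batch, head) pair per invocation. An equivalent CPU reference for validating it; `transposeBNSHtoBSNH` is a hypothetical name, not part of this patch:

```ts
// Re-lays [B, N, S, H] (head-major) into [B, S, N*H] (token-major), matching
// the offset arithmetic in the WGSL above.
const transposeBNSHtoBSNH = (input: Float32Array, B: number, N: number, S: number, H: number): Float32Array => {
  const output = new Float32Array(B * S * N * H);
  for (let b = 0; b < B; b++) {
    for (let n = 0; n < N; n++) {
      let inputOffset = (b * N + n) * S * H;   // start of this head's rows
      let outputOffset = (b * S * N + n) * H;  // interleaved destination
      for (let s = 0; s < S; s++) {
        output.set(input.subarray(inputOffset, inputOffset + H), outputOffset);
        inputOffset += H;
        outputOffset += N * H;  // stride over the hidden dimension
      }
    }
  }
  return output;
};
```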
diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts
index a230a4bace3ee..45ee20dfc1364 100644
--- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts
+++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts
@@ -57,40 +57,40 @@ export class ProgramManager {
 
     this.backend.pendingDispatchNumber++;
 
-    // this.backend.endComputePass();
-    // const kernelId = this.backend.currentKernelId!;
-    // const kernelName = this.backend.kernels.get(kernelId)![0];
-    // for (const output of outputs) {
-    //   const stagingBuffer = this.backend.device.createBuffer({
-    //     size: output.buffer.size,
-    //     // eslint-disable-next-line no-bitwise
-    //     usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
-    //   });
-    //
-    //   const commandEncoder = this.backend.getCommandEncoder();
-    //   commandEncoder?.copyBufferToBuffer(
-    //       output.buffer,
-    //       0,  // Source offset
-    //       stagingBuffer,
-    //       0,  // Destination offset
-    //       output.buffer.size,
-    //   );
-    //   this.backend.flush();
-    //
-    //   stagingBuffer
-    //       .mapAsync(
-    //           GPUMapMode.READ,
-    //           0,  // Offset
-    //           output.buffer.size,
-    //           )
-    //       .then(() => {
-    //         const copyArrayBuffer = stagingBuffer.getMappedRange(0, output.buffer.size);
-    //         const data = copyArrayBuffer.slice(0);
-    //         stagingBuffer.unmap();
-    //         console.log(`${kernelId}|${kernelName}:`);
-    //         console.log(new Float32Array(data));
-    //       });
-    // }
+    this.backend.endComputePass();
+    const kernelId = this.backend.currentKernelId!;
+    const kernelName = this.backend.kernels.get(kernelId)![0];
+    for (const output of outputs) {
+      const stagingBuffer = this.backend.device.createBuffer({
+        size: output.buffer.size,
+        // eslint-disable-next-line no-bitwise
+        usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
+      });
+
+      const commandEncoder = this.backend.getCommandEncoder();
+      commandEncoder?.copyBufferToBuffer(
+          output.buffer,
+          0,  // Source offset
+          stagingBuffer,
+          0,  // Destination offset
+          output.buffer.size,
+      );
+      this.backend.flush();
+
+      stagingBuffer
+          .mapAsync(
+              GPUMapMode.READ,
+              0,  // Offset
+              output.buffer.size,
+              )
+          .then(() => {
+            const copyArrayBuffer = stagingBuffer.getMappedRange(0, output.buffer.size);
+            const data = copyArrayBuffer.slice(0);
+            stagingBuffer.unmap();
+            console.log(`${kernelId}|${kernelName}:`);
+            console.log(new Float32Array(data));
+          });
+    }
 
     if (profilingEnabled) {
       // profiling write end timestamp
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc
index 4711ccf487cc8..3d2f7a29d41f8 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -327,6 +327,23 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
     });
   }
 
+  std::cout << "Prepare completed." << std::endl;
+  std::cout << "First 10 values at Q: ";
+  for (size_t i = 0; i < std::min<size_t>(10, static_cast<size_t>(qkv_head_size[0]) * sequence_length * batch_size * num_heads_); ++i) {
+    std::cout << Q[i] << " ";
+  }
+  std::cout << std::endl;
+  std::cout << "First 10 values at K: ";
+  for (size_t i = 0; i < std::min<size_t>(10, static_cast<size_t>(qkv_head_size[1]) * sequence_length * batch_size * num_heads_); ++i) {
+    std::cout << K[i] << " ";
+  }
+  std::cout << std::endl;
+  std::cout << "First 10 values at V: ";
+  for (size_t i = 0; i < std::min<size_t>(10, static_cast<size_t>(qkv_head_size[2]) * sequence_length * batch_size * num_heads_); ++i) {
+    std::cout << V[i] << " ";
+  }
+  std::cout << std::endl;
+
   // Compute the attention score and apply the score to V
   return ApplyAttention(Q, K, V, mask_index, past, nullptr /* past_key */, nullptr /* past_value */,
                         output, nullptr /* present_key */, nullptr /* present_value */,
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
index b761b1afd8529..040364d256b7b 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -200,12 +200,24 @@ class AttentionCPUBase : public AttentionBase {
       });
     }
 
+    std::cout << "Probs before softmax." << std::endl;
+    for (size_t i = 0; i < static_cast<size_t>(total_sequence_length) * sequence_length * batch_size * num_heads_; ++i) {
+      std::cout << attention_probs[i] << " ";
+    }
+    std::cout << std::endl;
+
     //  attention_probs(B, N, S, T) = Softmax(attention_probs)
     {
       const int N = batch_size * num_heads_ * sequence_length;
       const int D = total_sequence_length;
       ComputeAttentionSoftmaxInplace(attention_probs, N, D, tp);
     }
+
+    std::cout << "Probs after softmax." << std::endl;
+    for (size_t i = 0; i < static_cast<size_t>(total_sequence_length) * sequence_length * batch_size * num_heads_; ++i) {
+      std::cout << attention_probs[i] << " ";
+    }
+    std::cout << std::endl;
   }
 
   template
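The program-manager.ts hunk re-enables the standard WebGPU read-back pattern for debugging kernel outputs: storage buffers cannot be mapped directly, so each output is copied into a MAP_READ staging buffer, mapped asynchronously, and logged. The same pattern in self-contained form, a sketch assuming only a plain `GPUDevice`/`GPUBuffer` pair rather than the backend wrapper used above:

```ts
// Copy a GPU buffer into a mappable staging buffer and read it back as f32.
async function readGpuBuffer(device: GPUDevice, source: GPUBuffer, size: number): Promise<Float32Array> {
  const staging = device.createBuffer({
    size,
    // eslint-disable-next-line no-bitwise
    usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
  });
  const encoder = device.createCommandEncoder();
  encoder.copyBufferToBuffer(source, 0, staging, 0, size);
  device.queue.submit([encoder.finish()]);
  await staging.mapAsync(GPUMapMode.READ, 0, size);
  const data = staging.getMappedRange(0, size).slice(0);  // copy out before unmap detaches it
  staging.unmap();
  staging.destroy();
  return new Float32Array(data);
}
```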