From 2dff8d5d02beb5fcf116b581e1a924f965213ef9 Mon Sep 17 00:00:00 2001
From: Satya Jandhyala
Date: Fri, 25 Oct 2024 07:27:04 -0700
Subject: [PATCH] Use output buffer instead of computing softmax in-place.

---
 js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 25 +++++++++++++------------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
index 6a78c8ae3b190..33f25af1ea207 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
@@ -288,7 +288,7 @@ const initVarStub = (
   }
 };
 
-const createInPlaceSoftmaxProgramInfo = (
+const createSoftmaxProgramInfo = (
   input: TensorView,
   batchSize: number,
   numHeads: number,
@@ -324,7 +324,8 @@ const createInPlaceSoftmaxProgramInfo = (
     inputDependencies.push('type');
   }
   const getShaderSource = (shaderHelper: ShaderHelper) => {
-    const inputHelper = outputVariable('x', input.dataType, input.dims, components);
+    const inputHelper = inputVariable('x', input.dataType, input.dims, components);
+    const outputHelper = outputVariable('y', input.dataType, input.dims, components);
     const inputHelpers = [inputHelper];
     const seqLensInputHelper = seqLens ? inputVariable('seq_lens', seqLens.dataType, seqLens.dims) : undefined;
     if (seqLensInputHelper) {
@@ -350,7 +351,7 @@ const createInPlaceSoftmaxProgramInfo = (
     return `
   var<workgroup> thread_max: array<f32, ${WG}>;
   var<workgroup> thread_sum: array<f32, ${WG}>;
-  ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputHelpers)}
+  ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputHelpers, outputHelper)}
   ${shaderHelper.mainStart([WG, 1, 1])}
     let batchIdx = workgroup_id.z / uniforms.num_heads;
     let headIdx = workgroup_id.z % uniforms.num_heads;
@@ -408,19 +409,19 @@ const createInPlaceSoftmaxProgramInfo = (
 
     if (sum == 0) {
       for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
-        x[offset + i] = ${inputHelper.type.value}(${elemValueType}(1.0) / ${elemValueType}(seq_causal_length));
+        y[offset + i] = ${inputHelper.type.value}(${elemValueType}(1.0) / ${elemValueType}(seq_causal_length));
       }
     } else {
       for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
         var f32input = ${f32Type}(x[offset + i]);
-        x[offset + i] = ${inputHelper.type.value}(exp(f32input - max_value) / sum);
+        y[offset + i] = ${inputHelper.type.value}(exp(f32input - max_value) / sum);
       }
     }
       ${
         seqLens
           ? `
         for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length; total_seq_id++) {
-          x[offset + total_seq_id] = ${inputHelper.type.value}(${elemValueType}(0));
+          y[offset + total_seq_id] = ${inputHelper.type.value}(${elemValueType}(0));
         }`
           : ''
       };
@@ -432,7 +433,7 @@ const createInPlaceSoftmaxProgramInfo = (
     shaderCache: { hint: `${WG};${dataType};${components}`, inputDependencies },
     getShaderSource,
     getRunData: () => ({
-      outputs: [],
+      outputs: [{ dims: input.dims, dataType: input.dataType, gpuDataType: GpuDataType.default }],
       dispatchGroup: { x: Math.ceil(totalSequenceLength / WG), y: sequenceLength, z: batchSize * numHeads },
       programUniforms,
     }),
@@ -844,8 +845,8 @@ export const applyAttention = (
   )[0];
 
   // Run Softmax
-  context.compute(
-    createInPlaceSoftmaxProgramInfo(
+  const softmaxOutput = context.compute(
+    createSoftmaxProgramInfo(
       probs,
       parameters.batchSize,
       parameters.numHeads,
@@ -855,11 +856,11 @@ export const applyAttention = (
       seqLens,
       totalSequenceLengthInput,
     ),
-    { inputs: seqLens && totalSequenceLengthInput ? [probs, seqLens, totalSequenceLengthInput] : [probs], outputs: [] },
-  );
+    { inputs: seqLens && totalSequenceLengthInput ? [probs, seqLens, totalSequenceLengthInput] : [probs], outputs: [-1] },
+  )[0];
 
   // Run AttentionScore
-  const inputsV = [probs, v];
+  const inputsV = [softmaxOutput, v];
   if (outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) {
     inputsV.push(pastValue);
   }
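For concreteness, here is a CPU-side TypeScript sketch of what one row of the softmax shader computes after this change, writing into a separate output buffer `y` instead of mutating `x` in place. The function name and signature are illustrative, not part of the patch; the uniform fallback for `sum == 0` and the zero-fill past `seq_causal_length` mirror the WGSL branches visible in the diff.

function softmaxRow(x: Float32Array, y: Float32Array, seqCausalLength: number): void {
  // Safe softmax: subtract the row max before exponentiating.
  let maxValue = -Infinity;
  for (let i = 0; i < seqCausalLength; i++) {
    maxValue = Math.max(maxValue, x[i]);
  }
  let sum = 0;
  for (let i = 0; i < seqCausalLength; i++) {
    sum += Math.exp(x[i] - maxValue);
  }
  if (sum === 0) {
    // Degenerate row: fall back to a uniform distribution, as the shader does.
    for (let i = 0; i < seqCausalLength; i++) {
      y[i] = 1 / seqCausalLength;
    }
  } else {
    for (let i = 0; i < seqCausalLength; i++) {
      y[i] = Math.exp(x[i] - maxValue) / sum;
    }
  }
  // Entries beyond the causal length are zeroed (the seqLens branch above).
  for (let i = seqCausalLength; i < y.length; i++) {
    y[i] = 0;
  }
}

On the dispatch side, switching from `outputs: []` to `outputs: [-1]` makes `context.compute` return a runtime-allocated intermediate tensor (`softmaxOutput`), so `probs` stays a read-only input (`x` via inputVariable) and the subsequent AttentionScore step consumes the new tensor in its place.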