From c4ade796d657e179248975423c3e4d3e8f3a6e29 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Sun, 18 Aug 2024 08:15:15 +0800 Subject: [PATCH] [js/webgpu] Fix attention shader recompilation issue (#21770) ### Description This PR fixes the `AttentionProbsSoftmax` recompilation issue when executing the phi3 model. With this fix, it will further improve the phi3 performance. ### Motivation and Context --- js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index 8840ef97b4279..300bc01859def 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -262,9 +262,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor let WG = 64; const dComp = d / components; if (dComp < WG) { - WG = 1; - } else if (dComp / 8 < 64) { - WG = Math.ceil(dComp / 8); + WG = 32; } const elementsPerThread = Math.ceil(d / components / WG); const programUniforms: ProgramUniform[] = [ @@ -274,7 +272,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor ]; const dataType = tensorTypeToWsglStorageType(input.dataType, components); const f32Type = tensorTypeToWsglValueType(DataType.float, components); - + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type']; const getShaderSource = (shaderHelper: ShaderHelper) => { const inputHelper = outputVariable('x', input.dataType, input.dims, components); const elemValueType = tensorTypeToWsglValueType(input.dataType); @@ -353,7 +351,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor return { name: 'AttentionProbsSoftmax', - shaderCache: { hint: `${WG};${dataType};${components}` }, + shaderCache: { hint: `${WG};${dataType};${components}`, inputDependencies }, getShaderSource, getRunData: () => ({ outputs: [], dispatchGroup: { x: n }, programUniforms }), };