diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index 79b24e9c4d67a..94ad67d3c3b0a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -5,7 +5,7 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ComputeContext, GpuDataType, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, tensorTypeToWsglValueType, UniformDataElementType, UniformsArrayType} from './common'; +import {getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, tensorTypeToWsglValueType, UniformDataElementType, UniformsArrayType} from './common'; import {createConcatProgramInfo} from './concat'; export const enum AttentionQkvFormat { @@ -351,31 +351,27 @@ const createAttentionProbsProgramInfo = const programUniforms: ProgramUniform[] = [ {type: DataType.uint32, data: parameters.sequenceLength}, {type: DataType.uint32, data: vectorizedHeadSize}, {type: DataType.uint32, data: totalSequenceLength}, {type: DataType.uint32, data: parameters.numHeads}, - {type: q.dataType, data: alpha} + {type: DataType.float, data: alpha} ]; - const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; - if (relativePositionBias) { - inputDependencies.push('rank'); - programUniforms.push(...createTensorShapeVariables(relativePositionBias.dims)); - } + const inputDependencies: ProgramInputTensorInfoDependency[] = + relativePositionBias ? ['type', 'type', 'type'] : ['type', 'type']; const getShaderSource = (shaderHelper: ShaderHelper) => { const qInput = inputVariable('q', q.dataType, q.dims, components); const kInput = inputVariable('key', key.dataType, key.dims, components); const inputVars = [qInput, kInput]; - const relativePositionBiasInput = relativePositionBias ? - inputVariable('relative_position_bias', relativePositionBias.dataType, relativePositionBias.dims.length) : - undefined; - if (relativePositionBiasInput) { - inputVars.push(relativePositionBiasInput); + if (relativePositionBias) { + inputVars.push( + inputVariable('relative_position_bias', relativePositionBias.dataType, relativePositionBias.dims)); } const output = outputVariable('output', q.dataType, probsShape); - const dataType = tensorTypeToWsglStorageType(q.dataType); + // const dataType = tensorTypeToWsglStorageType(q.dataType); + const f32Type = tensorTypeToWsglValueType(DataType.float, components); const uniforms: UniformsArrayType = [ {name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'}, - {name: 'num_heads', type: 'u32'}, {name: 'alpha', type: dataType as UniformDataElementType} + {name: 'num_heads', type: 'u32'}, {name: 'alpha', type: 'f32' as UniformDataElementType} ]; return ` const TILE_SIZE = ${TILE_SIZE}u; @@ -393,7 +389,7 @@ const createAttentionProbsProgramInfo = let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K; let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K; - var value = ${qInput.type.value}(0); + var value = ${f32Type}(0); for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) { tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x]; @@ -404,7 +400,7 @@ const createAttentionProbsProgramInfo = workgroupBarrier(); for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) { - value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]; + value += ${f32Type}(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]); } workgroupBarrier(); @@ -413,7 +409,7 @@ const createAttentionProbsProgramInfo = let headOffset = headIdx * uniforms.M * uniforms.N; if (global_id.y < uniforms.M && global_id.x < uniforms.N) { let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x; - var sum = ${(() => { + var sum: f32 = ${(() => { switch (components) { case 1: return 'value'; @@ -425,17 +421,8 @@ const createAttentionProbsProgramInfo = throw new Error(`Unsupported components: ${components}`); } })()}; - - ${(() => { - if (relativePositionBiasInput) { - return ` - let batch = workgroup_id.z / uniforms.num_heads; - let head = workgroup_id.z % uniforms.num_heads; - var indices = ${relativePositionBiasInput.type.indices}(batch, head, global_id.y, global_id.x); - output[outputIdx] = sum * uniforms.alpha + ${relativePositionBiasInput.getByIndices('indices')};`; - } - return 'output[outputIdx] = sum * uniforms.alpha;'; - })()} + output[outputIdx] = ${output.type.value} (sum * uniforms.alpha) + ${ + relativePositionBias ? 'relative_position_bias[outputIdx]' : '0.0'}; } }`; }; @@ -502,7 +489,7 @@ const createVxAttentionScoreProgramInfo = tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * uniforms.N]; } workgroupBarrier(); - for (var k: u32 = 0u; k