[JS/WebGPU] Support fp16 in Attention by performing the computation in fp32. (#20486)

### Description
Perform the computation in fp32 and convert the final result to fp16.
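
The diff below applies one pattern throughout the shader: the `alpha` uniform is declared as `f32`, each tile product is widened to `f32` before it is accumulated, and the running sum is cast back to the output's storage type only on the final store. A minimal sketch of that pattern is shown here; `buildDotSnippet`, `a`, `b`, `outputIdx`, and `storageType` are illustrative names, not the actual helpers from `./common`.

```ts
// Sketch of the fp32-accumulate / fp16-store pattern, assuming the caller
// passes the WGSL storage type of the tensors ('f16' for fp16 inputs).
const buildDotSnippet = (storageType: 'f32' | 'f16'): string => `
  var value: f32 = 0.0;  // the accumulator is always fp32
  for (var k: u32 = 0u; k < uniforms.K; k++) {
    // widen the (possibly fp16) product before accumulating
    value += f32(a[k] * b[k]);
  }
  // cast back to the storage type exactly once, on the final store;
  // uniforms.alpha itself is declared as f32
  output[outputIdx] = ${storageType}(value * uniforms.alpha);
`;

// For fp16 tensors the generated WGSL still never accumulates in f16:
const fp16Snippet = buildDotSnippet('f16');
```

This keeps intermediate precision independent of the input element type, at the cost of a widening cast in the inner loop and a narrowing cast on the store.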



### Motivation and Context
satyajandhyala authored Apr 27, 2024
1 parent ff505b9 commit 736cbb3
Showing 1 changed file with 17 additions and 30 deletions.
js/web/lib/wasm/jsep/webgpu/ops/attention.ts
@@ -5,7 +5,7 @@ import {DataType} from '../../../wasm-common';
 import {TensorView} from '../../tensor-view';
 import {ComputeContext, GpuDataType, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
 
-import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, tensorTypeToWsglValueType, UniformDataElementType, UniformsArrayType} from './common';
+import {getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, tensorTypeToWsglValueType, UniformDataElementType, UniformsArrayType} from './common';
 import {createConcatProgramInfo} from './concat';
 
 export const enum AttentionQkvFormat {
@@ -351,31 +351,27 @@ const createAttentionProbsProgramInfo =
       const programUniforms: ProgramUniform[] = [
         {type: DataType.uint32, data: parameters.sequenceLength}, {type: DataType.uint32, data: vectorizedHeadSize},
         {type: DataType.uint32, data: totalSequenceLength}, {type: DataType.uint32, data: parameters.numHeads},
-        {type: q.dataType, data: alpha}
+        {type: DataType.float, data: alpha}
       ];
 
-      const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
-      if (relativePositionBias) {
-        inputDependencies.push('rank');
-        programUniforms.push(...createTensorShapeVariables(relativePositionBias.dims));
-      }
+      const inputDependencies: ProgramInputTensorInfoDependency[] =
+          relativePositionBias ? ['type', 'type', 'type'] : ['type', 'type'];
 
       const getShaderSource = (shaderHelper: ShaderHelper) => {
         const qInput = inputVariable('q', q.dataType, q.dims, components);
         const kInput = inputVariable('key', key.dataType, key.dims, components);
         const inputVars = [qInput, kInput];
-        const relativePositionBiasInput = relativePositionBias ?
-            inputVariable('relative_position_bias', relativePositionBias.dataType, relativePositionBias.dims.length) :
-            undefined;
-        if (relativePositionBiasInput) {
-          inputVars.push(relativePositionBiasInput);
+        if (relativePositionBias) {
+          inputVars.push(
+              inputVariable('relative_position_bias', relativePositionBias.dataType, relativePositionBias.dims));
         }
         const output = outputVariable('output', q.dataType, probsShape);
-        const dataType = tensorTypeToWsglStorageType(q.dataType);
+        // const dataType = tensorTypeToWsglStorageType(q.dataType);
+        const f32Type = tensorTypeToWsglValueType(DataType.float, components);
 
         const uniforms: UniformsArrayType = [
           {name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'},
-          {name: 'num_heads', type: 'u32'}, {name: 'alpha', type: dataType as UniformDataElementType}
+          {name: 'num_heads', type: 'u32'}, {name: 'alpha', type: 'f32' as UniformDataElementType}
         ];
         return `
   const TILE_SIZE = ${TILE_SIZE}u;
@@ -393,7 +389,7 @@
     let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;
     let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;
-    var value = ${qInput.type.value}(0);
+    var value = ${f32Type}(0);
     for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
       if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {
         tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];
       }
@@ -404,7 +400,7 @@
       workgroupBarrier();
       for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {
-        value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k];
+        value += ${f32Type}(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);
       }
       workgroupBarrier();
     }
 
@@ -413,7 +409,7 @@
     let headOffset = headIdx * uniforms.M * uniforms.N;
     if (global_id.y < uniforms.M && global_id.x < uniforms.N) {
       let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;
-      var sum = ${(() => {
+      var sum: f32 = ${(() => {
        switch (components) {
          case 1:
            return 'value';
@@ -425,17 +421,8 @@
            throw new Error(`Unsupported components: ${components}`);
        }
      })()};
-      ${(() => {
-        if (relativePositionBiasInput) {
-          return `
-          let batch = workgroup_id.z / uniforms.num_heads;
-          let head = workgroup_id.z % uniforms.num_heads;
-          var indices = ${relativePositionBiasInput.type.indices}(batch, head, global_id.y, global_id.x);
-          output[outputIdx] = sum * uniforms.alpha + ${relativePositionBiasInput.getByIndices('indices')};`;
-        }
-        return 'output[outputIdx] = sum * uniforms.alpha;';
-      })()}
+      output[outputIdx] = ${output.type.value} (sum * uniforms.alpha) + ${
+          relativePositionBias ? 'relative_position_bias[outputIdx]' : '0.0'};
     }
   }`;
      };
 
@@ -502,7 +489,7 @@ const createVxAttentionScoreProgramInfo =
         tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * uniforms.N];
       }
       workgroupBarrier();
-      for (var k: u32 = 0u; k<TILE_SIZE && w+k < uniforms.K; k++) {
+      for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {
        value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];
      }
      workgroupBarrier();
@@ -512,7 +499,7 @@
     let batchIdx = workgroup_id.z / uniforms.num_heads;
     let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads;
     if (m < uniforms.M && n < uniforms.N) {
-      let outputIdx = batchIdx * uniforms.M *uniforms.v_hidden_size + m * uniforms.v_hidden_size
+      let outputIdx = batchIdx * uniforms.M * uniforms.v_hidden_size + m * uniforms.v_hidden_size
         + currentBatchHeadNumber * uniforms.N + n;
       output[outputIdx] = value;
     }
