Skip to content

Commit

Permalink
Use mainStart
Browse files: browse the repository at this point in the history
  • Loading branch information
axinging committed Dec 27, 2023
1 parent d47cc8e commit 66a3fc0
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 33 deletions.
47 changes: 15 additions & 32 deletions js/web/lib/wasm/jsep/webgpu/ops/attention.ts
Original file line number Diff line number Diff line change
Expand Up @@ -265,17 +265,17 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
var<workgroup> wgMax: array<f32, ${WG}>;
var<workgroup> wgSum: array<f32, ${WG}>;
${shaderHelper.registerUniforms(uniforms).declareVariables(inputHelper)}
@compute @workgroup_size(${WG}, 1, 1)
fn main(@builtin(workgroup_id) workgroup_id : vec3<u32>,
@builtin(local_invocation_index) local_index : u32) {
let localOffset = local_index * uniforms.elements_per_wg;
${shaderHelper.mainStart([
WG, 1, 1
])}
let localOffset = local_idx * uniforms.elements_per_wg;
let offset: u32 = workgroup_id.x * uniforms.d_comp + localOffset;
var thread_max_vector = ${fillVector('f32', components, '-3.402823e+38f')};
for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) {
thread_max_vector = max(${castToF32(elemValueType, components, 'x[offset + i]')}, thread_max_vector);
}
wgMax[local_index] = ${threadMaxValue};
wgMax[local_idx] = ${threadMaxValue};
workgroupBarrier();
var maxValue = -3.402823e+38f;
Expand All @@ -287,7 +287,7 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) {
sumVector += exp(${castToF32(elemValueType, components, 'x[offset + i]')} - maxValue);
}
wgSum[local_index] = ${sumVector('sumVector', components)};
wgSum[local_idx] = ${sumVector('sumVector', components)};
workgroupBarrier();
var sum: f32 = 0;
Expand Down Expand Up @@ -356,21 +356,15 @@ const computeAttentionProbs =
{name: 'kv_sequence_length', type: 'u32'}, {name: 'alpha', type: dataType as UniformDataElementType}
];
return `
const beta: ${dataType} = 1.0;
const TILE_SIZE = ${TILE_SIZE}u;
var<workgroup> tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
var<workgroup> tileK: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
${shaderHelper.registerUniforms(uniforms).declareVariables(qInput, kInput, output)}
@compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1)
fn main(@builtin(workgroup_id) workgroup_id : vec3<u32>,
@builtin(local_invocation_id) local_id : vec3<u32>,
@builtin(local_invocation_index) local_index : u32,
@builtin(num_workgroups) num_workgroups : vec3<u32>) {
let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] +
workgroup_id.y * num_workgroups[0] + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index;
${shaderHelper.mainStart([
TILE_SIZE, TILE_SIZE, 1
])}
// x holds the N and y holds the M
let headIdx = workgroup_id.z;
let m = workgroup_id.y * TILE_SIZE;
Expand Down Expand Up @@ -454,14 +448,9 @@ const computeVxAttentionScore =
var<workgroup> tileQ: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>;
var<workgroup> tileK: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>;
${shaderHelper.registerUniforms(uniforms).declareVariables(probsHelper, vHelper, output)}
@compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1)
fn main(@builtin(workgroup_id) workgroup_id : vec3<u32>,
@builtin(local_invocation_id) local_id : vec3<u32>,
@builtin(local_invocation_index) local_index : u32,
@builtin(num_workgroups) num_workgroups : vec3<u32>) {
let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] +
workgroup_id.y * num_workgroups[0] + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index;
${shaderHelper.mainStart([
TILE_SIZE, TILE_SIZE, 1
])}
let headIdx = workgroup_id.z;
let m = workgroup_id.y * TILE_SIZE + local_id.y;
let n = workgroup_id.x * TILE_SIZE + local_id.x;
Expand Down Expand Up @@ -562,16 +551,10 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
var<workgroup> tileWeightQ: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
var<workgroup> tileWeightK: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
var<workgroup> tileWeightV: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
${shaderHelper.registerUniforms(uniforms).declareVariables(input, weight, bias, outputQ, outputK, outputV)}
@compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1)
fn main(@builtin(workgroup_id) workgroup_id : vec3<u32>,
@builtin(local_invocation_id) local_id : vec3<u32>,
@builtin(local_invocation_index) local_index : u32,
@builtin(num_workgroups) num_workgroups : vec3<u32>) {
let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] +
workgroup_id.y * num_workgroups[0] + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index;
${shaderHelper.mainStart([
TILE_SIZE, TILE_SIZE, 1
])}
let batchIndex = workgroup_id.z / uniforms.num_heads;
let headNumber = workgroup_id.z % uniforms.num_heads;
let m = workgroup_id.y * TILE_SIZE + local_id.y;
Expand Down
4 changes: 3 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -771,8 +771,10 @@ class ShaderHelperImpl implements ShaderHelper {

const is1DimensionDispatch = this.normalizedDispatchGroup[1] === 1 && this.normalizedDispatchGroup[2] === 1;
const paramList = is1DimensionDispatch ? `@builtin(global_invocation_id) global_id : vec3<u32>,
@builtin(workgroup_id) workgroup_id : vec3<u32>,
@builtin(local_invocation_id) local_id : vec3<u32>` :
`@builtin(local_invocation_index) local_idx : u32,
`@builtin(local_invocation_id) local_id : vec3<u32>,
@builtin(local_invocation_index) local_idx : u32,
@builtin(workgroup_id) workgroup_id : vec3<u32>,
@builtin(num_workgroups) num_workgroups : vec3<u32>`;
const globalIdxDefinition = is1DimensionDispatch ?
Expand Down

0 comments on commit 66a3fc0

Please sign in to comment.