From 66a3fc04c0b62184e52bc0a49267a857b1dbcce7 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Wed, 27 Dec 2023 16:34:14 +0800 Subject: [PATCH] Use mainStart --- js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 47 +++++++------------- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 4 +- 2 files changed, 18 insertions(+), 33 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index a9b9a4aa2e56e..65f5f01a1b97b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -265,17 +265,17 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView var wgMax: array; var wgSum: array; ${shaderHelper.registerUniforms(uniforms).declareVariables(inputHelper)} - @compute @workgroup_size(${WG}, 1, 1) - fn main(@builtin(workgroup_id) workgroup_id : vec3, - @builtin(local_invocation_index) local_index : u32) { - let localOffset = local_index * uniforms.elements_per_wg; + ${shaderHelper.mainStart([ + WG, 1, 1 + ])} + let localOffset = local_idx * uniforms.elements_per_wg; let offset: u32 = workgroup_id.x * uniforms.d_comp + localOffset; var thread_max_vector = ${fillVector('f32', components, '-3.402823e+38f')}; for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) { thread_max_vector = max(${castToF32(elemValueType, components, 'x[offset + i]')}, thread_max_vector); } - wgMax[local_index] = ${threadMaxValue}; + wgMax[local_idx] = ${threadMaxValue}; workgroupBarrier(); var maxValue = -3.402823e+38f; @@ -287,7 +287,7 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) { sumVector += exp(${castToF32(elemValueType, components, 'x[offset + i]')} - maxValue); } - wgSum[local_index] = ${sumVector('sumVector', components)}; + wgSum[local_idx] = ${sumVector('sumVector', components)}; workgroupBarrier(); var sum: f32 = 0; @@ -356,21 +356,15 @@ const computeAttentionProbs = {name: 'kv_sequence_length', type: 'u32'}, {name: 'alpha', type: dataType as UniformDataElementType} ]; return ` - const beta: ${dataType} = 1.0; const TILE_SIZE = ${TILE_SIZE}u; var tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>; var tileK: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>; ${shaderHelper.registerUniforms(uniforms).declareVariables(qInput, kInput, output)} - @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) - fn main(@builtin(workgroup_id) workgroup_id : vec3, - @builtin(local_invocation_id) local_id : vec3, - @builtin(local_invocation_index) local_index : u32, - @builtin(num_workgroups) num_workgroups : vec3) { - let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] + - workgroup_id.y * num_workgroups[0] + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; - + ${shaderHelper.mainStart([ + TILE_SIZE, TILE_SIZE, 1 + ])} // x holds the N and y holds the M let headIdx = workgroup_id.z; let m = workgroup_id.y * TILE_SIZE; @@ -454,14 +448,9 @@ const computeVxAttentionScore = var tileQ: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>; var tileK: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>; ${shaderHelper.registerUniforms(uniforms).declareVariables(probsHelper, vHelper, output)} - @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) - fn main(@builtin(workgroup_id) workgroup_id : vec3, - @builtin(local_invocation_id) local_id : vec3, - @builtin(local_invocation_index) local_index : u32, - @builtin(num_workgroups) num_workgroups : vec3) { - let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] + - workgroup_id.y * num_workgroups[0] + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; - + ${shaderHelper.mainStart([ + TILE_SIZE, TILE_SIZE, 1 + ])} let headIdx = workgroup_id.z; let m = workgroup_id.y * TILE_SIZE + local_id.y; let n = workgroup_id.x * TILE_SIZE + local_id.x; @@ -562,16 +551,10 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => { var tileWeightQ: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; var tileWeightK: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; var tileWeightV: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; - ${shaderHelper.registerUniforms(uniforms).declareVariables(input, weight, bias, outputQ, outputK, outputV)} - @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) - fn main(@builtin(workgroup_id) workgroup_id : vec3, - @builtin(local_invocation_id) local_id : vec3, - @builtin(local_invocation_index) local_index : u32, - @builtin(num_workgroups) num_workgroups : vec3) { - let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] + - workgroup_id.y * num_workgroups[0] + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; - + ${shaderHelper.mainStart([ + TILE_SIZE, TILE_SIZE, 1 + ])} let batchIndex = workgroup_id.z / uniforms.num_heads; let headNumber = workgroup_id.z % uniforms.num_heads; let m = workgroup_id.y * TILE_SIZE + local_id.y; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 0eb0d40a3ea5e..06997487c1608 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -771,8 +771,10 @@ class ShaderHelperImpl implements ShaderHelper { const is1DimensionDispatch = this.normalizedDispatchGroup[1] === 1 && this.normalizedDispatchGroup[2] === 1; const paramList = is1DimensionDispatch ? `@builtin(global_invocation_id) global_id : vec3, + @builtin(workgroup_id) workgroup_id : vec3, @builtin(local_invocation_id) local_id : vec3` : - `@builtin(local_invocation_index) local_idx : u32, + `@builtin(local_invocation_id) local_id : vec3, + @builtin(local_invocation_index) local_idx : u32, @builtin(workgroup_id) workgroup_id : vec3, @builtin(num_workgroups) num_workgroups : vec3`; const globalIdxDefinition = is1DimensionDispatch ?