From 8475617b75f568b2c51bbe513d3968d3ecdaa8c2 Mon Sep 17 00:00:00 2001 From: Arthur Islamov Date: Mon, 11 Sep 2023 21:28:16 +0400 Subject: [PATCH] Attention optimizations WIP --- js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts | 102 ++++++++++++------ onnxruntime/contrib_ops/js/bias_add.cc | 4 +- onnxruntime/contrib_ops/js/bias_split_gelu.cc | 4 +- onnxruntime/contrib_ops/js/gelu.cc | 4 +- 4 files changed, 77 insertions(+), 37 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts index ca5e60d1e8050..d53d65448e768 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts @@ -318,7 +318,7 @@ const computeAttentionProbs = const N = parameters.totalSequenceLength; const K = vectorizedHeadSize; - const TILE_SIZE = Math.min(8, vectorizedHeadSize); + const TILE_SIZE = 16; const dispatch = { x: Math.ceil(parameters.totalSequenceLength / TILE_SIZE), @@ -403,13 +403,15 @@ const computeAttentionProbs = const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => { const outputShape = [params.batchSize, params.numHeads, params.sequenceLength, params.vHeadSize]; - const probsHelper = inputVariable('probs', probs.dataType, probs.dims); - const vHelper = inputVariable('v', v.dataType, v.dims); + const components = getMaxComponents(params.totalSequenceLength); + const probsHelper = inputVariable('probs', probs.dataType, probs.dims, components); + const vHelper = inputVariable('v', v.dataType, v.dims, components); const output = outputVariable('output', probs.dataType, outputShape); const dataType = tensorTypeToWsglStorageType(probs.dataType); - const TILE_SIZE = 8; + + const TILE_SIZE = 16; const dispatch = { x: Math.ceil(params.vHeadSize / TILE_SIZE), y: Math.ceil(params.sequenceLength / TILE_SIZE), @@ -419,7 +421,7 @@ const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v: const getShaderSource = (shaderHelper: ShaderHelper) => ` const M: u32 = ${params.sequenceLength}u; const N: u32 = ${params.vHeadSize}u; - const K: u32 = ${params.totalSequenceLength}u; + const K: u32 = ${params.totalSequenceLength / components}u; const numHeads: u32 = ${params.numHeads}u; const TILE_SIZE = ${TILE_SIZE}u; @@ -441,7 +443,7 @@ const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v: let offsetA = headIdx * (M * K) + m * K; let offsetB = headIdx * (N * K) + n; - var value = ${dataType}(0); + var value = ${fillVector(dataType, components)}; for (var w: u32 = 0u; w < K; w += TILE_SIZE) { if (m < M && w + local_id.x < K) { tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x]; @@ -459,7 +461,7 @@ const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v: let headOffset = headIdx * M * N; if (m < M && n < N) { let outputIdx = headOffset + m * N + n; - output[outputIdx] = value; + output[outputIdx] = ${sumVector('value', components)}; } }`; @@ -504,20 +506,31 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters, attri // TODO: handle mask // const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale; - const gemmSize = parameters.sequenceLength * parameters.hiddenSize; - const unitsOfWork = gemmSize * parameters.batchSize * parameters.numHeads; const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); const M = parameters.sequenceLength; const K = parameters.inputHiddenSize; const N = parameters.headSize; + const TILE_SIZE = 16; + const dispatch = { + x: Math.ceil(parameters.headSize / TILE_SIZE), + y: Math.ceil(parameters.sequenceLength / TILE_SIZE), + z: parameters.batchSize * parameters.numHeads + }; + const getShaderSource = (shaderHelper: ShaderHelper) => ` const M: u32 = ${M}u; const K: u32 = ${K}u; const N: u32 = ${N}u; const numHeads: u32 = ${parameters.numHeads}; const ldb = ${parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize}u; + const TILE_SIZE = ${TILE_SIZE}u; + + var tileInput: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; + var tileWeightQ: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; + var tileWeightK: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; + var tileWeightV: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; @group(0) @binding(0) var input: array<${dataType}>; @group(0) @binding(1) var weight: array<${dataType}>; @@ -526,39 +539,60 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters, attri @group(0) @binding(4) var outputK: array<${dataType}>; @group(0) @binding(5) var outputV: array<${dataType}>; - ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(unitsOfWork)} - let gemmSize = M * N; - let idxWoGemmSize = global_idx / gemmSize; - let batchIndex = idxWoGemmSize / numHeads; - let headIndex = idxWoGemmSize % numHeads; + @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) + fn main(@builtin(workgroup_id) workgroup_id : vec3, + @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { + let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + + workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; - let gemmOffset = global_idx % gemmSize; - let m = gemmOffset / N; - let n = gemmOffset % N; + let batchIndex = workgroup_id.z / ${parameters.batchSize}; + let headNumber = workgroup_id.z % ${parameters.batchSize}; + let m = workgroup_id.y * TILE_SIZE + local_id.y; + let n = workgroup_id.x * TILE_SIZE + local_id.x; let inputOffset = batchIndex * ${parameters.sequenceLength * parameters.inputHiddenSize} + m * K; - let biasOffsetQ = headIndex * ${parameters.headSize}; + let biasOffsetQ = headNumber * ${parameters.headSize}; let biasOffsetK = ${parameters.hiddenSize} + biasOffsetQ; let biasOffsetV = ${parameters.hiddenSize} + biasOffsetK; - var value = vec3<${dataType}>(0, 0, 0); - for (var k: u32 = 0u; k<${K}u; k++) { - let a = input[k + inputOffset]; - let itemWeightOffset = k * ldb + n; - value[0] += a * weight[itemWeightOffset + biasOffsetQ]; - value[1] += a * weight[itemWeightOffset + biasOffsetK]; - value[2] += a * weight[itemWeightOffset + biasOffsetV]; + var valueQ = ${dataType}(0); + var valueK = ${dataType}(0); + var valueV = ${dataType}(0); + for (var w: u32 = 0u; w < K; w += TILE_SIZE) { + if (m < M && w + local_id.x < K) { + tileInput[TILE_SIZE * local_id.y + local_id.x] = input[inputOffset + w + local_id.x]; + } + if (n < N && w + local_id.y < K) { + tileWeightQ[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetQ + (w + local_id.y) * ldb]; + } + if (n < N && w + local_id.y < K) { + tileWeightK[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + (w + local_id.y) * ldb]; + } + if (n < N && w + local_id.y < K) { + tileWeightV[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + (w + local_id.y) * ldb]; + } + workgroupBarrier(); + for (var k: u32 = 0u; k ({x: Math.ceil(unitsOfWork / 64 /* workgroup size */)}) + dispatchGroup: () => (dispatch) }, {inputs, outputs: [-1, -1, -1]}); }; diff --git a/onnxruntime/contrib_ops/js/bias_add.cc b/onnxruntime/contrib_ops/js/bias_add.cc index 51a1073f5ce56..9e70dead6a5da 100644 --- a/onnxruntime/contrib_ops/js/bias_add.cc +++ b/onnxruntime/contrib_ops/js/bias_add.cc @@ -7,13 +7,15 @@ namespace onnxruntime { namespace contrib { namespace js { +using onnxruntime::js::JsepSupportedFloatTypes; + ONNX_OPERATOR_KERNEL_EX( BiasAdd, kMSDomain, 1, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()), + .TypeConstraint("T", JsepSupportedFloatTypes()), BiasAdd); } // namespace js diff --git a/onnxruntime/contrib_ops/js/bias_split_gelu.cc b/onnxruntime/contrib_ops/js/bias_split_gelu.cc index efc52af2330ba..e16aa4367d1c7 100644 --- a/onnxruntime/contrib_ops/js/bias_split_gelu.cc +++ b/onnxruntime/contrib_ops/js/bias_split_gelu.cc @@ -7,13 +7,15 @@ namespace onnxruntime { namespace contrib { namespace js { +using onnxruntime::js::JsepSupportedFloatTypes; + ONNX_OPERATOR_KERNEL_EX( BiasSplitGelu, kMSDomain, 1, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()), + .TypeConstraint("T", JsepSupportedFloatTypes()), BiasSplitGelu); } // namespace js diff --git a/onnxruntime/contrib_ops/js/gelu.cc b/onnxruntime/contrib_ops/js/gelu.cc index 57de4e21a200e..3f4a0275fa532 100644 --- a/onnxruntime/contrib_ops/js/gelu.cc +++ b/onnxruntime/contrib_ops/js/gelu.cc @@ -7,13 +7,15 @@ namespace onnxruntime { namespace contrib { namespace js { +using onnxruntime::js::JsepSupportedFloatTypes; + ONNX_OPERATOR_KERNEL_EX( Gelu, kMSDomain, 1, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()), + .TypeConstraint("T", JsepSupportedFloatTypes()), Gelu); } // namespace js