From 070f06a6eb9c344e3dc484b84c6554b7e8c192da Mon Sep 17 00:00:00 2001
From: Arthur Islamov
Date: Tue, 26 Sep 2023 17:17:02 +0400
Subject: [PATCH] Fixes

Re-enable the vec4 path in the WebGPU matmul, reduce TILE_SIZE from 16
to 8 in the attention shaders and drop the unused vectorization in
computeVxAttentionScore, run the karma ChromeTest launcher on Chrome
Canary with the experimental wasm flags, and extend the attention2
test data.
---
 js/web/karma.conf.js                          | 14 ++-
 .../ops/3rd-party/matmul_packed_webgpu.ts     |  7 +-
 js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts | 47 +++++-----
 js/web/test/data/ops/attention2.jsonc         | 91 ++++++++++++++++++-
 4 files changed, 122 insertions(+), 37 deletions(-)

diff --git a/js/web/karma.conf.js b/js/web/karma.conf.js
index a670204cd7ea2..28540ba40cf3a 100644
--- a/js/web/karma.conf.js
+++ b/js/web/karma.conf.js
@@ -92,7 +92,19 @@ module.exports = function(config) {
     customLaunchers: {
       // the following flags are used to make sure Edge on CI agents to initialize WebGPU correctly.
       EdgeWebGpuTest: {base: 'Edge', flags: ['--ignore-gpu-blocklist', '--gpu-vendor-id=0x10de']},
-      ChromeTest: {base: 'Chrome', flags: ['--enable-features=SharedArrayBuffer']},
+      ChromeTest: {
+        base: 'ChromeCanary',
+        flags: [
+          '--remote-debugging-port=9333', '--enable-features=SharedArrayBuffer',
+          '--enable-experimental-web-platform-features',
+          '--js-flags=--experimental-wasm-type-reflection --experimental-wasm-stack-switching --experimental-wasm-memory64',
+          '--enable-dawn-features=allow_unsafe_apis'
+        ],
+        chromeFlags: [
+          '--experimental-wasm-memory64', '--experimental-wasm-stack-switching', '--experimental-wasm-type-reflection',
+          '--js-flags=--experimental-wasm-type-reflection --experimental-wasm-stack-switching --experimental-wasm-memory64'
+        ],
+      },
       ChromeTestHeadless: {base: 'ChromeHeadless', flags: ['--enable-features=SharedArrayBuffer']},
       ChromeDebug: {
         debug: true,
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
index bc8c27b6a49d3..7c3b012de6f26 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
@@ -441,10 +441,7 @@ export const createMatmulProgramInfo =
       const dimAOuter = aShape[aShape.length - 2];
       const dimInner = aShape[aShape.length - 1];
      const dimBOuter = bShape[bShape.length - 1];
-      // const isVec4 = dimInner % 4 === 0 && dimBOuter % 4 === 0;
-      // TODO: There is bug for input[0]:1,56x56,96, input[1]:1,96,24 in vec4 version.
-      const isVec4 = false;
-      const component = isVec4 ? 4 : 1;
+      const isVec4 = dimInner % 4 === 0 && dimBOuter % 4 === 0;
       const {activationFunction, applyActivation} = getActicationSnippet(activationAttributes);
       // TODO: fine tune size
@@ -467,7 +464,7 @@
       variables.push(output);
       const inputVariables = [A, B];
       const hasBias = inputs.length > 2;
-      const declareFunctions = matMulReadWriteFnSource(component, hasBias, applyActivation, variables);
+      const declareFunctions = matMulReadWriteFnSource(components, hasBias, applyActivation, variables);
       if (hasBias) {
         inputVariables.push(inputVariable('bias', inputs[2].dataType, [dimBOuter / components], components));
       }
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts
index b4246886048ee..222a1ce91c95a 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts
@@ -318,7 +318,7 @@ const computeAttentionProbs =
     const N = parameters.totalSequenceLength;
     const K = vectorizedHeadSize;

-    const TILE_SIZE = 16;
+    const TILE_SIZE = 8;

     const dispatch = {
       x: Math.ceil(parameters.totalSequenceLength / TILE_SIZE),
@@ -403,15 +403,13 @@
 const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => {
   const outputShape = [params.batchSize, params.numHeads, params.sequenceLength, params.vHeadSize];
-  const components = 1;  //getMaxComponents(params.totalSequenceLength);
-  const probsHelper = inputVariable('probs', probs.dataType, probs.dims, components);
-  const vHelper = inputVariable('v', v.dataType, v.dims, components);
+  const probsHelper = inputVariable('probs', probs.dataType, probs.dims);
+  const vHelper = inputVariable('v', v.dataType, v.dims);
   const output = outputVariable('output', probs.dataType, outputShape);
   const dataType = tensorTypeToWsglStorageType(probs.dataType);
-
-  const TILE_SIZE = 16;
+  const TILE_SIZE = 8;
   const dispatch = {
     x: Math.ceil(params.vHeadSize / TILE_SIZE),
     y: Math.ceil(params.sequenceLength / TILE_SIZE),
@@ -421,9 +419,10 @@ const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v:
   const getShaderSource = (shaderHelper: ShaderHelper) => `
     const M: u32 = ${params.sequenceLength}u;
     const N: u32 = ${params.vHeadSize}u;
-    const K: u32 = ${params.totalSequenceLength / components}u;
+    const K: u32 = ${params.totalSequenceLength}u;
+    const numHeads: u32 = ${params.numHeads}u;
     const TILE_SIZE = ${TILE_SIZE}u;
-
+
    var<workgroup> tileQ: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
    var<workgroup> tileK: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>;

@@ -436,33 +435,31 @@ const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v:
         workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index;
     let headIdx = workgroup_id.z;
-    let m = workgroup_id.y * TILE_SIZE;
-    let n = workgroup_id.x * TILE_SIZE;
-    let lm = m + local_id.y;
-    let ln = n + local_id.x;
+    let m = workgroup_id.y * TILE_SIZE + local_id.y;
+    let n = workgroup_id.x * TILE_SIZE + local_id.x;

     let offsetA = headIdx * (M * K) + m * K;
     let offsetB = headIdx * (N * K) + n;
-    var value = ${fillVector(dataType, components)};
+    var value = ${dataType}(0);
     for (var w: u32 = 0u; w < K; w += TILE_SIZE) {
-      if (m + local_id.y < M && w + local_id.x < K) {
-        tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + local_id.y * K + w + local_id.x];
+      if (m < M && w + local_id.x < K) {
+        tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];
       }
-      if (n + local_id.y < N && w + local_id.x < K) {
-        tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + local_id.y * K + w + local_id.x];
+      if (n < N && w + local_id.y < K) {
+        tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * N];
       }
       workgroupBarrier();

       for (var k: u32 = 0u; k
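
Note on the indexing change in computeVxAttentionScore: each TILE_SIZE x TILE_SIZE workgroup now owns one output tile and each invocation owns one output element, so m and n are per-invocation global coordinates rather than tile bases (the old lm/ln pair is gone). A minimal TypeScript sketch of that mapping, not part of the patch; the z extent is an assumption inferred from `let headIdx = workgroup_id.z` and the [batchSize, numHeads, sequenceLength, vHeadSize] output shape:

    // Sketch only: how one dispatch cell maps to one output element.
    const TILE_SIZE = 8;  // matches the patch

    function outputElementFor(
        workgroupId: {x: number, y: number, z: number},
        localId: {x: number, y: number}): {headIdx: number, m: number, n: number} {
      return {
        headIdx: workgroupId.z,                    // assumed: one z-slice per (batch, head)
        m: workgroupId.y * TILE_SIZE + localId.y,  // output row in [0, sequenceLength)
        n: workgroupId.x * TILE_SIZE + localId.x,  // output col in [0, vHeadSize)
      };
    }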
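For reference, the rewritten shader body is a plain tiled matmul: per head, probs is M x K, v is K x N, and output[m][n] = sum over k of probs[m][k] * v[k][n]. A minimal CPU sketch of the same arithmetic (hypothetical helper, not part of the patch), assuming row-major layouts, which also shows why the shader walks v with stride N in `v[offsetB + (w + local_id.y) * N]`:

    // CPU reference for the per-head probs x V product computed above.
    function vxScoreReference(
        probs: Float32Array, v: Float32Array, numHeads: number,
        M: number, K: number, N: number): Float32Array {
      const output = new Float32Array(numHeads * M * N);
      for (let headIdx = 0; headIdx < numHeads; headIdx++) {
        const offsetA = headIdx * M * K;  // start of this head's probs block
        const offsetB = headIdx * K * N;  // start of this head's v block
        for (let m = 0; m < M; m++) {
          for (let n = 0; n < N; n++) {
            let value = 0;
            for (let k = 0; k < K; k++) {
              // v is read down a column with stride N, as in the shader.
              value += probs[offsetA + m * K + k] * v[offsetB + k * N + n];
            }
            output[headIdx * M * N + m * N + n] = value;
          }
        }
      }
      return output;
    }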