From 469e19c1e8ef3135f06b7c29a0e8c8ef1f53abb8 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 5 Nov 2024 07:05:21 +0800 Subject: [PATCH] [js/webgpu] Optimize Gemm (#22706) BUG #22031 The total Gemm time in demucs model becomes 181.14 ms from over 1000 ms on my iGPUs. ### Description ### Motivation and Context --- js/web/lib/wasm/jsep/webgpu/ops/gemm.ts | 161 +++++++++++++++++++++++- 1 file changed, 160 insertions(+), 1 deletion(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts index 7f2469d95e1c1..09365f3b984b4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts @@ -55,9 +55,15 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt if (!outputShape) { throw new Error("Can't use gemm on the given tensors"); } + const tileSize = 16; + const numTileN = Math.ceil(N / tileSize); + const numTileM = Math.ceil(M / tileSize); + // TODO: Find the condition when to use the naive one. + const useShared = true; + const outputSize = ShapeUtil.size(outputShape); const programUniforms: ProgramUniform[] = [ - { type: DataType.uint32, data: outputSize }, + { type: DataType.uint32, data: useShared ? numTileN : outputSize }, { type: DataType.uint32, data: M }, { type: DataType.uint32, data: N }, { type: DataType.uint32, data: K }, @@ -130,6 +136,159 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt }`; }; + const getShaderSourceShared = (shaderHelper: ShaderHelper) => { + const a = inputVariable('a', inputs[0].dataType, inputs[0].dims); + const b = inputVariable('b', inputs[1].dataType, inputs[1].dims); + let c: IndicesHelper | null = null; + const variables = [a, b]; + if (inputs.length === 3) { + c = inputVariable('c', inputs[2].dataType, inputs[2].dims.length); + variables.push(c); + } + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + variables.push(output); + const uniforms: UniformsArrayType = [ + { name: 'num_tile_n', type: 'u32' }, + { name: 'M', type: 'u32' }, + { name: 'N', type: 'u32' }, + { name: 'K', type: 'u32' }, + { name: 'alpha', type: 'f32' }, + { name: 'beta', type: 'f32' }, + ]; + + let calcResult = ''; + let fillWorkgroupMemory = ''; + if (attributes.transA && attributes.transB) { + fillWorkgroupMemory = ` + var col = tile_row_start + local_id.x; + var row = k_start + local_id.y; + if (col < uniforms.M && row < uniforms.K) { + tile_a[local_id.y][local_id.x] = a[row * uniforms.M + col]; + } else { + tile_a[local_id.y][local_id.x] = ${a.type.value}(0); + } + + col = k_start + local_id.x; + row = tile_col_start + local_id.y; + if (col < uniforms.K && row < uniforms.N) { + tile_b[local_id.y][local_id.x] = b[row * uniforms.K + col]; + } else { + tile_b[local_id.y][local_id.x] = ${b.type.value}(0); + } + `; + calcResult = `value += tile_a[k][local_id.y] * tile_b[local_id.x][k];`; + } else if (attributes.transA && !attributes.transB) { + fillWorkgroupMemory = ` + var col = tile_row_start + local_id.x; + var row = k_start + local_id.y; + if (col < uniforms.M && row < uniforms.K) { + tile_a[local_id.y][local_id.x] = a[row * uniforms.M + col]; + } else { + tile_a[local_id.y][local_id.x] = ${a.type.value}(0); + } + + col = tile_col_start + local_id.x; + row = k_start + local_id.y; + if (col < uniforms.N && row < uniforms.K) { + tile_b[local_id.y][local_id.x] = b[row * uniforms.N + col]; + } else { + tile_b[local_id.y][local_id.x] = ${b.type.value}(0); + } + `; + calcResult = `value += tile_a[k][local_id.y] * tile_b[k][local_id.x];`; + } else if (!attributes.transA && attributes.transB) { + fillWorkgroupMemory = ` + var col = k_start + local_id.x; + var row = tile_row_start + local_id.y; + if (col < uniforms.K && row < uniforms.M) { + tile_a[local_id.y][local_id.x] = a[row * uniforms.K + col]; + } else { + tile_a[local_id.y][local_id.x] = ${a.type.value}(0); + } + + col = k_start + local_id.x; + row = tile_col_start + local_id.y; + if (col < uniforms.K && row < uniforms.N) { + tile_b[local_id.y][local_id.x] = b[row * uniforms.K + col]; + } else { + tile_b[local_id.y][local_id.x] = ${b.type.value}(0); + } + `; + calcResult = `value += tile_a[local_id.y][k] * tile_b[local_id.x][k];`; + } else if (!attributes.transA && !attributes.transB) { + fillWorkgroupMemory = ` + var col = k_start + local_id.x; + var row = tile_row_start + local_id.y; + if (col < uniforms.K && row < uniforms.M) { + tile_a[local_id.y][local_id.x] = a[row * uniforms.K + col]; + } else { + tile_a[local_id.y][local_id.x] = ${a.type.value}(0); + } + + col = tile_col_start + local_id.x; + row = k_start + local_id.y; + if (col < uniforms.N && row < uniforms.K) { + tile_b[local_id.y][local_id.x] = b[row * uniforms.N + col]; + } else { + tile_b[local_id.y][local_id.x] = ${b.type.value}(0); + } + `; + calcResult = `value += tile_a[local_id.y][k] * tile_b[k][local_id.x];`; + } + + const calculateAlpha = attributes.alpha === 1 ? '' : 'value *= uniforms.alpha;'; + + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)} + var tile_a: array, ${tileSize}>; + var tile_b: array, ${tileSize}>; + ${shaderHelper.mainStart([tileSize, tileSize, 1])} + let tile_col_start = (workgroup_index % uniforms.num_tile_n) * ${tileSize}; + let tile_row_start = (workgroup_index / uniforms.num_tile_n) * ${tileSize}; + let num_tiles = (uniforms.K - 1) / ${tileSize} + 1; + var k_start = 0u; + var value = ${output.type.value}(0); + for (var t: u32 = 0u; t < num_tiles; t++) { + ${fillWorkgroupMemory} + k_start = k_start + ${tileSize}; + workgroupBarrier(); + + for (var k: u32 = 0u; k < ${tileSize}; k++) { + ${calcResult} + } + workgroupBarrier(); + } + + ${calculateAlpha} + let m = tile_row_start + local_id.y; + let n = tile_col_start + local_id.x; + ${(() => { + if (c != null) { + return `let cOffset = ${c.broadcastedIndicesToOffset('vec2(m, n)', output)}; value += ${ + output.type.value + }(uniforms.beta) * ${c.getByOffset('cOffset')};`; + } + return ''; + })()} + if (m < uniforms.M && n < uniforms.N) { + output[m * uniforms.N + n] = value; + } + }`; + }; + + if (useShared) { + return { + name: 'GemmShared', + shaderCache: { hint: `${attributes.cacheKey}`, inputDependencies }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: numTileN * numTileM }, + programUniforms, + }), + getShaderSource: getShaderSourceShared, + }; + } + return { name: 'Gemm', shaderCache: { hint: `${attributes.cacheKey}`, inputDependencies },