From 557ac74c0565b7cfdb68bb79724d7ed1b870c121 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Wed, 10 Jan 2024 01:57:06 +0800 Subject: [PATCH] [js/webgpu] Support gemm uniforms (#19056) ### Description ### Motivation and Context --- js/web/lib/wasm/jsep/webgpu/ops/gemm.ts | 123 ++++++++++++------------ 1 file changed, 64 insertions(+), 59 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts index 1c5d28e4b8e3f..30754c84413b7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts @@ -3,10 +3,10 @@ import {TensorView} from '../../tensor-view'; import {GemmUtil, ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {AttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs) { @@ -34,25 +34,6 @@ export interface GemmAttributes extends AttributeWithCacheKey { beta: number; } -const offsetC = (m: number, n: number, dims: readonly number[]): string => { - if (dims.length === 0) { - return '0u'; - } - - const broadcastM = (dims.length === 1 && m !== 1) || (dims.length === 2 && dims[0] !== m); - const broadcastN = dims[dims.length - 1] !== n; - - let offset = '0u'; - if (!broadcastM) { - offset += `+ m * ${dims[dims.length - 1]}u`; - } - if (!broadcastN) { - offset += '+n'; - } - - return offset; -}; - const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAttributes): ProgramInfo => { const aShape = inputs[0].dims.slice(); const bShape = inputs[1].dims.slice(); @@ -63,68 +44,92 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt throw new Error('Can\'t use gemm on the given tensors'); } const outputSize = ShapeUtil.size(outputShape); - let line = ''; - if (attributes.transA && attributes.transB) { - line = 'value += a[k * M + m] * b[n * K + k];'; - } else if (attributes.transA && !attributes.transB) { - line = 'value += a[k * M + m] * b[k * N + n];'; - } else if (!attributes.transA && attributes.transB) { - line = 'value += a[m * K + k] * b[n * K + k];'; - } else if (!attributes.transA && !attributes.transB) { - line = 'value += a[m * K + k] * b[k * N + n];'; - } - - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const calculateAlpha = attributes.alpha === 1 ? '' : 'value *= alpha;'; - const calculateC = inputs.length === 3 ? `value += beta * c[${offsetC(M, N, inputs[2].dims)}];` : ''; - const inputStorageBuffersDeclarations = [ - `@group(0) @binding(0) var a : array<${dataType}>;`, - `@group(0) @binding(1) var b : array<${dataType}>;` + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N}, {type: 'uint32', data: K}, + {type: 'float32', data: attributes.alpha}, {type: 'float32', data: attributes.beta} ]; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; if (inputs.length === 3) { - inputStorageBuffersDeclarations.push(`@group(0) @binding(2) var c : array<${dataType}>;`); + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); } - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const M: u32 = ${M}u; - const N: u32 = ${N}u; - const K: u32 = ${K}u; - const alpha = ${dataType}(${attributes.alpha}); - const beta = ${dataType}(${attributes.beta}); + programUniforms.push(...createTensorShapeVariables(outputShape)); + + const getShaderSource = (shaderHelper: ShaderHelper) => { + let line = ''; + if (attributes.transA && attributes.transB) { + line = 'value += a[k * uniforms.M + m] * b[n * uniforms.K + k];'; + } else if (attributes.transA && !attributes.transB) { + line = 'value += a[k * uniforms.M + m] * b[k * uniforms.N + n];'; + } else if (!attributes.transA && attributes.transB) { + line = 'value += a[m * uniforms.K + k] * b[n * uniforms.K + k];'; + } else if (!attributes.transA && !attributes.transB) { + line = 'value += a[m * uniforms.K + k] * b[k * uniforms.N + n];'; + } - ${inputStorageBuffersDeclarations.join('\n')} - @group(0) @binding(${inputs.length}) var output : array<${dataType}>; + const calculateAlpha = attributes.alpha === 1 ? '' : 'value *= uniforms.alpha;'; + const a = inputVariable('a', inputs[0].dataType, inputs[0].dims); + const b = inputVariable('b', inputs[1].dataType, inputs[1].dims); + const dataType = a.type.value; + let c: IndicesHelper|null = null; + const variables = [a, b]; + if (inputs.length === 3) { + c = inputVariable('c', inputs[2].dataType, inputs[2].dims.length); + variables.push(c); + } + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + variables.push(output); + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'M', type: 'u32'}, {name: 'N', type: 'u32'}, {name: 'K', type: 'u32'}, + {name: 'alpha', type: 'f32'}, {name: 'beta', type: 'f32'} + ]; + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} - let m = global_idx / N; - let n = global_idx % N; + let m = global_idx / uniforms.N; + let n = global_idx % uniforms.N; var value = ${dataType}(0); - for (var k: u32 = 0u; k<${K}u; k++) { + for (var k: u32 = 0u; k < uniforms.K; k++) { ${line} } ${calculateAlpha} - ${calculateC} + ${(() => { + if (c != null) { + return `let cOffset = ${c.broadcastedIndicesToOffset('vec2(m, n)', output)}; value += uniforms.beta * ${ + c.getByOffset('cOffset')};`; + } + return ''; + })()} output[global_idx] = value; - }`; + }; + return { name: 'Gemm', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: {hint: `${attributes.cacheKey}`, inputDependencies}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }), getShaderSource, }; }; +export const parseGemmAttributes = (attributes: Record): GemmAttributes => { + const transA = attributes.transA as boolean; + const transB = attributes.transB as boolean; + const alpha = attributes.alpha as number; + const beta = attributes.beta as number; + return {transA, transB, alpha, beta, cacheKey: `${attributes.transA};${attributes.transB};${attributes.alpha === 1}`}; +}; + export const gemm = (context: ComputeContext, attributes: GemmAttributes): void => { validateInputs(context.inputs); context.compute(createGemmProgramInfo(context.inputs, attributes)); }; - -export const parseGemmAttributes = (attributes: Record): GemmAttributes => - createAttributeWithCacheKey(attributes as Omit);