Skip to content

Commit

Permalink
[js/webgpu] Support gemm uniforms (#19056)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
axinging authored Jan 9, 2024
1 parent 42ba2ae commit 557ac74
Showing 1 changed file with 64 additions and 59 deletions.
123 changes: 64 additions & 59 deletions js/web/lib/wasm/jsep/webgpu/ops/gemm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

import {TensorView} from '../../tensor-view';
import {GemmUtil, ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';
import {AttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';

import {ShaderHelper, tensorTypeToWsglStorageType} from './common';
import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';

const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs) {
Expand Down Expand Up @@ -34,25 +34,6 @@ export interface GemmAttributes extends AttributeWithCacheKey {
beta: number;
}

const offsetC = (m: number, n: number, dims: readonly number[]): string => {
if (dims.length === 0) {
return '0u';
}

const broadcastM = (dims.length === 1 && m !== 1) || (dims.length === 2 && dims[0] !== m);
const broadcastN = dims[dims.length - 1] !== n;

let offset = '0u';
if (!broadcastM) {
offset += `+ m * ${dims[dims.length - 1]}u`;
}
if (!broadcastN) {
offset += '+n';
}

return offset;
};

const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAttributes): ProgramInfo => {
const aShape = inputs[0].dims.slice();
const bShape = inputs[1].dims.slice();
Expand All @@ -63,68 +44,92 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt
throw new Error('Can\'t use gemm on the given tensors');
}
const outputSize = ShapeUtil.size(outputShape);
let line = '';
if (attributes.transA && attributes.transB) {
line = 'value += a[k * M + m] * b[n * K + k];';
} else if (attributes.transA && !attributes.transB) {
line = 'value += a[k * M + m] * b[k * N + n];';
} else if (!attributes.transA && attributes.transB) {
line = 'value += a[m * K + k] * b[n * K + k];';
} else if (!attributes.transA && !attributes.transB) {
line = 'value += a[m * K + k] * b[k * N + n];';
}

const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const calculateAlpha = attributes.alpha === 1 ? '' : 'value *= alpha;';
const calculateC = inputs.length === 3 ? `value += beta * c[${offsetC(M, N, inputs[2].dims)}];` : '';
const inputStorageBuffersDeclarations = [
`@group(0) @binding(0) var<storage, read> a : array<${dataType}>;`,
`@group(0) @binding(1) var<storage, read> b : array<${dataType}>;`
const programUniforms: ProgramUniform[] = [
{type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N}, {type: 'uint32', data: K},
{type: 'float32', data: attributes.alpha}, {type: 'float32', data: attributes.beta}
];
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
if (inputs.length === 3) {
inputStorageBuffersDeclarations.push(`@group(0) @binding(2) var<storage, read> c : array<${dataType}>;`);
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
inputDependencies.push('rank');
}
const getShaderSource = (shaderHelper: ShaderHelper) => `
const M: u32 = ${M}u;
const N: u32 = ${N}u;
const K: u32 = ${K}u;
const alpha = ${dataType}(${attributes.alpha});
const beta = ${dataType}(${attributes.beta});
programUniforms.push(...createTensorShapeVariables(outputShape));

const getShaderSource = (shaderHelper: ShaderHelper) => {
let line = '';
if (attributes.transA && attributes.transB) {
line = 'value += a[k * uniforms.M + m] * b[n * uniforms.K + k];';
} else if (attributes.transA && !attributes.transB) {
line = 'value += a[k * uniforms.M + m] * b[k * uniforms.N + n];';
} else if (!attributes.transA && attributes.transB) {
line = 'value += a[m * uniforms.K + k] * b[n * uniforms.K + k];';
} else if (!attributes.transA && !attributes.transB) {
line = 'value += a[m * uniforms.K + k] * b[k * uniforms.N + n];';
}

${inputStorageBuffersDeclarations.join('\n')}
@group(0) @binding(${inputs.length}) var<storage, read_write> output : array<${dataType}>;
const calculateAlpha = attributes.alpha === 1 ? '' : 'value *= uniforms.alpha;';
const a = inputVariable('a', inputs[0].dataType, inputs[0].dims);
const b = inputVariable('b', inputs[1].dataType, inputs[1].dims);
const dataType = a.type.value;
let c: IndicesHelper|null = null;
const variables = [a, b];
if (inputs.length === 3) {
c = inputVariable('c', inputs[2].dataType, inputs[2].dims.length);
variables.push(c);
}
const output = outputVariable('output', inputs[0].dataType, outputShape.length);
variables.push(output);
const uniforms: UniformsArrayType = [
{name: 'output_size', type: 'u32'}, {name: 'M', type: 'u32'}, {name: 'N', type: 'u32'}, {name: 'K', type: 'u32'},
{name: 'alpha', type: 'f32'}, {name: 'beta', type: 'f32'}
];
return `
${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
let m = global_idx / N;
let n = global_idx % N;
let m = global_idx / uniforms.N;
let n = global_idx % uniforms.N;
var value = ${dataType}(0);
for (var k: u32 = 0u; k<${K}u; k++) {
for (var k: u32 = 0u; k < uniforms.K; k++) {
${line}
}
${calculateAlpha}
${calculateC}
${(() => {
if (c != null) {
return `let cOffset = ${c.broadcastedIndicesToOffset('vec2(m, n)', output)}; value += uniforms.beta * ${
c.getByOffset('cOffset')};`;
}
return '';
})()}
output[global_idx] = value;
}`;
};

return {
name: 'Gemm',
shaderCache: {hint: attributes.cacheKey},
shaderCache: {hint: `${attributes.cacheKey}`, inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms
}),
getShaderSource,
};
};

export const parseGemmAttributes = (attributes: Record<string, unknown>): GemmAttributes => {
const transA = attributes.transA as boolean;
const transB = attributes.transB as boolean;
const alpha = attributes.alpha as number;
const beta = attributes.beta as number;
return {transA, transB, alpha, beta, cacheKey: `${attributes.transA};${attributes.transB};${attributes.alpha === 1}`};
};

export const gemm = (context: ComputeContext, attributes: GemmAttributes): void => {
validateInputs(context.inputs);
context.compute(createGemmProgramInfo(context.inputs, attributes));
};

export const parseGemmAttributes = (attributes: Record<string, unknown>): GemmAttributes =>
createAttributeWithCacheKey(attributes as Omit<GemmAttributes, keyof AttributeWithCacheKey>);

0 comments on commit 557ac74

Please sign in to comment.