Skip to content

Commit

Permalink
[js/web] allow ShaderHelper to use internal (non-I/O) variables (#18525)
Browse files Browse the repository at this point in the history
### Description
This PR includes a change that inspired from #18452 to resolve a
requirement: a shader may depend on an instance of `IndicesHelper` to
generate WGSL code snippet, but the IndicesHelper instance is not
necessarily an input/output of the program. So the existing
`declareVariables()` function does not work with this scenario.

In order to support this requirement, I added this "use" function to
`interface ShaderHelper`, which takes a helper-like object as parameter.
The hidden implementation `ShaderHelperImpl` class will iterate the
helpers and call `impl()` for each.

@axinging @qjia7
  • Loading branch information
fs-eire authored Nov 28, 2023
1 parent a49f31b commit 50e6235
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 51 deletions.
26 changes: 9 additions & 17 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import {TensorView} from '../../../tensor-view';
import {ShapeUtil} from '../../../util';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
import {createTensorShapeVariables, enableShapesUniforms, getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common';
import {createTensorShapeVariables, enableShapesUniforms, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common';
import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils';

import {typeSnippet} from './activation_util';
Expand Down Expand Up @@ -341,13 +341,8 @@ fn main(@builtin(local_invocation_id) localId : vec3<u32>,
const matMulReadWriteFnSource =
(component: number, hasBias: boolean, applyActivation: string, variables: IndicesHelper[],
batchShapes: Array<readonly number[]>, isChannelsLast = false): string => {
const batchAShape = batchShapes[0];
const batchBShape = batchShapes[1];
const batchShape = batchShapes[2];
const batchVariable = variables[0];
const aVariable = variables[1];
const bVariable = variables[2];
const outputVariable = variables[3];
const [batchAShape, batchBShape, batchShape] = batchShapes;
const [batchVariable, aVariable, bVariable, outputVariable] = variables;
const broadCastADims = getBroadcastDims(batchAShape, batchShape);
const broadCastBDims = getBroadcastDims(batchBShape, batchShape);
const dataType = tensorTypeToWsglStorageType(variables[0].type.tensor);
Expand Down Expand Up @@ -434,9 +429,7 @@ export const createMatmulProgramInfo =
const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2);
const enableBatchUniforms = enableShapesUniforms(outerDims.length);
const batchShapeOrRank = enableBatchUniforms ? outerDims.length : outerDims;
const batchDims = inputVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1, true);
const variables = [batchDims];
const batchShapes = [outerDimsA, outerDimsB, outerDims];
const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1);
const batchSize = ShapeUtil.size(outerDims);

const dimAOuter = aShape[aShape.length - 2];
Expand Down Expand Up @@ -469,10 +462,7 @@ export const createMatmulProgramInfo =
const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components);
const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components);
const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components);
variables.push(A);
variables.push(B);
variables.push(output);
const inputVariables = [batchDims, A, B];
const inputVariables = [A, B];
const programUniforms: ProgramUniform[] =
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
if (enableBatchUniforms) {
Expand All @@ -490,8 +480,9 @@ export const createMatmulProgramInfo =

const hasBias = inputs.length > 2;
const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value);
const declareFunctions =
matMulReadWriteFnSource(components, hasBias, applyActivation, variables, batchShapes, isChannelsLast);
const declareFunctions = matMulReadWriteFnSource(
components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims],
isChannelsLast);
if (hasBias) {
const biasComponents = isChannelsLast ? components : 1;
inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents));
Expand All @@ -506,6 +497,7 @@ export const createMatmulProgramInfo =
shaderHelper.registerUniform('dimAOuter', 'i32')
.registerUniform('dimBOuter', 'i32')
.registerUniform('dimInner', 'i32')
.registerInternalVariables(batchDims)
.declareVariables(...inputVariables, output)}
${activationFunction}
${declareFunctions}
Expand Down
108 changes: 74 additions & 34 deletions js/web/lib/wasm/jsep/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,11 @@ interface IndicesHelperTypes {
* create an instance of an indices helper:
* - `inputVariable()`: create an indices helper instance for an input.
* - `outputVariable()`: create an indices helper instance for an output.
* - `internalVariable()`: create an indices helper instance for an internal variable.
*
* An indices helper instance contains helper functions for the following operations:
* - access readonly basic information, including: `name`(the name of the input or output), `usage`(whether it's an
* input or an output) and `shape`(the passed in shape).
* input, an output or an internal variable) and `shape`(the passed in shape).
* - `type`: access readonly type information, including: `indices`(the type of indices), `value`(the type of value at
* runtime), `storage`(the type of value at storage) and `tensor`(the tensor type as represented in TensorView).
* - generate WGSL code for getting indices from offset. Use `offsetToIndices()` for WGSL code snippet to calculate
Expand Down Expand Up @@ -192,9 +193,9 @@ export interface IndicesHelper {
readonly name: string;

/**
* whether the helper is for an input or an output.
* whether the helper is for an input, an output or an internal variable.
*/
readonly usage: 'input'|'output';
readonly usage: 'input'|'output'|'internal';

/**
* the rank of the input or output.
Expand All @@ -210,11 +211,6 @@ export interface IndicesHelper {
* a string representing the variable name for the strides of the input or output.
*/
readonly strides: string;

/**
* representing variable with uniforms, but without binding.
*/
readonly uniformOnly: boolean;
}

const getWgslMappedType = (type: number, components: 1|2|3|4): string|[string, string] => {
Expand Down Expand Up @@ -335,13 +331,13 @@ export const sumVector = (name: string, components: number) => {
* @param name - the name of the input or output.
* @param tensorType - the tensor type of the input or output.
* @param shapeOrRank - the tensor shape or the rank of the input or output.
* @param isInput - whether the helper is for an input or an output.
* @param usage - the usage of the indices helper.
* @param components - indicates the number of components of each element. 1 for scalar, 2 for vec2, 3 for vec3, 4 for
* vec4.
*/
const createIndicesHelper =
(name: string, tensorType: number, shapeOrRank: number|readonly number[], isInput: boolean, components: 1|2|3|4,
uniformOnly = false): IndicesHelper => {
(name: string, tensorType: number, shapeOrRank: number|readonly number[], usage: IndicesHelper['usage'],
components: 1|2|3|4): IndicesHelper => {
const useUniform = typeof shapeOrRank === 'number';
const rank = useUniform ? shapeOrRank : shapeOrRank.length;
const rankIdentity = [...new Array(rank).keys()];
Expand All @@ -363,7 +359,7 @@ const createIndicesHelper =
getByIndices: false,
};

const uniformPrefix = useUniform || uniformOnly ? 'uniforms.' : '';
const uniformPrefix = useUniform ? 'uniforms.' : '';
const shape = `${uniformPrefix}${name}_shape`;
const strides = `${uniformPrefix}${name}_strides`;
let o2iSnippet = '';
Expand Down Expand Up @@ -617,12 +613,11 @@ const createIndicesHelper =
getByOffset,
getByIndices,
// isVec4,
usage: isInput ? 'input' : 'output',
usage,
name,
strides,
shape,
rank,
uniformOnly
rank
};
};

Expand All @@ -636,8 +631,8 @@ const createIndicesHelper =
* @returns an IndicesHelper for the input.
*/
export const inputVariable =
(name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1, uniformOnly = false):
IndicesHelper => createIndicesHelper(name, type, shapeOrRank, true, components, uniformOnly);
(name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper =>
createIndicesHelper(name, type, shapeOrRank, 'input', components);

/**
* Create a IndicesHelper for an output.
Expand All @@ -650,7 +645,20 @@ export const inputVariable =
*/
export const outputVariable =
(name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper =>
createIndicesHelper(name, type, shapeOrRank, false, components);
createIndicesHelper(name, type, shapeOrRank, 'output', components);

/**
* Create a IndicesHelper for an internal variable.
*
* @param name - the name of the variable.
* @param type - the tensor type of the variable.
* @param shapeOrRank - the tensor shape or the rank of the variable.
* @param components - the number of components of the variable. available values are 1, 2, 3, 4. default is 1.
* @returns an IndicesHelper for the variable.
*/
export const internalVariable =
(name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper =>
createIndicesHelper(name, type, shapeOrRank, 'internal', components);

export type UniformsArrayType = Array<{name: string; type: string}>;

Expand Down Expand Up @@ -703,9 +711,27 @@ export interface ShaderHelper {

/**
* A helper function to register one uniform. Can be called multiple times to register multiple uniforms.
*
* @param name - the name of the uniform.
* @param type - the type of the uniform.
*/
registerUniform(name: string, type: string): ShaderHelper;
registerUniforms(nameToTypeMap: UniformsArrayType): ShaderHelper;

/**
* A helper function to register multiple uniforms. Can be called multiple times to register multiple uniforms.
*
* @param uniforms - an array of uniforms. Each element of the array is an object with 2 properties: `name` and
* `type`.
*/
registerUniforms(uniforms: UniformsArrayType): ShaderHelper;

/**
* A helper function to register multiple internal variables. Can be called multiple times to register multiple
* internal variables.
*
* @param variables - an array of IndicesHelper for the variables.
*/
registerInternalVariables(...variables: IndicesHelper[]): ShaderHelper;
}

class ShaderHelperImpl implements ShaderHelper {
Expand Down Expand Up @@ -740,8 +766,7 @@ class ShaderHelperImpl implements ShaderHelper {
`;
}

private declareVariable(variable: IndicesHelper, bindingIndex = -1): string {
this.indicesHelpers.push(variable);
private appendVariableUniforms(variable: IndicesHelper): void {
if (variable.rank !== 0) {
if (variable.shape.startsWith('uniforms.')) {
this.uniforms.push({name: variable.shape.replace('uniforms.', ''), type: variable.type.indices});
Expand All @@ -750,24 +775,37 @@ class ShaderHelperImpl implements ShaderHelper {
this.uniforms.push({name: variable.strides.replace('uniforms.', ''), type: variable.type.indices});
}
}
if (variable.uniformOnly) {
return '';
}

private declareVariable(variable: IndicesHelper, bindingIndex: number): string {
if (variable.usage === 'internal') {
throw new Error('cannot use internal variable with declareVariable(). use registerInternalVariables() instead.');
}
this.variables.push(variable);
this.appendVariableUniforms(variable);

const access = variable.usage === 'input' ? 'read' : 'read_write';
const storageType = variable.type.storage;
return `@group(0) @binding(${bindingIndex}) var<storage, ${access}> ${variable.name}: array<${storageType}>;`;
}

declareVariables(...variables: IndicesHelper[]): string {
return variables
.map(v => {
if (v.uniformOnly === true) {
return this.declareVariable(v);
} else {
return this.declareVariable(v, this.variableIndex++);
}
})
.join('\n');
return variables.map(v => this.declareVariable(v, this.variableIndex++)).join('\n');
}

private registerInternalVariable(variable: IndicesHelper): void {
if (variable.usage !== 'internal') {
throw new Error(
'cannot use input or output variable with registerInternalVariable(). use declareVariables() instead.');
}

this.internalVariables.push(variable);
this.appendVariableUniforms(variable);
}

registerInternalVariables(...variables: IndicesHelper[]): ShaderHelper {
variables.forEach(v => this.registerInternalVariable(v));
return this;
}

registerUniform(name: string, type: string): ShaderHelper {
Expand All @@ -780,7 +818,8 @@ class ShaderHelperImpl implements ShaderHelper {
return this;
}

private indicesHelpers: IndicesHelper[] = [];
private internalVariables: IndicesHelper[] = [];
private variables: IndicesHelper[] = [];
private uniforms: UniformsArrayType = [];
private uniformDeclaration(): string {
if (this.uniforms.length === 0) {
Expand All @@ -802,7 +841,8 @@ class ShaderHelperImpl implements ShaderHelper {
* Get additional implementation that needs to be added to the shader source.
*/
get additionalImplementations(): string {
return this.uniformDeclaration() + this.indicesHelpers.map(i => i.impl()).join('\n');
return this.uniformDeclaration() + this.variables.map(i => i.impl()).join('\n') +
this.internalVariables.map(i => i.impl()).join('\n');
}
}

Expand Down

0 comments on commit 50e6235

Please sign in to comment.