Skip to content

Commit

Permalink
update to apply revise as comments
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Nov 21, 2023
1 parent 8b2d0a4 commit c3f6e78
Showing 1 changed file with 56 additions and 10 deletions.
66 changes: 56 additions & 10 deletions js/web/lib/wasm/jsep/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,11 @@ interface IndicesHelperTypes {
* create an instance of an indices helper:
* - `inputVariable()`: create an indices helper instance for an input.
* - `outputVariable()`: create an indices helper instance for an output.
* - `internalVariable()`: create an indices helper instance for an internal variable.
*
* An indices helper instance contains helper functions for the following operations:
* - access readonly basic information, including: `name`(the name of the input or output), `usage`(whether it's an
* input or an output) and `shape`(the passed in shape).
* input, an output or an internal variable) and `shape`(the passed in shape).
* - `type`: access readonly type information, including: `indices`(the type of indices), `value`(the type of value at
* runtime), `storage`(the type of value at storage) and `tensor`(the tensor type as represented in TensorView).
* - generate WGSL code for getting indices from offset. Use `offsetToIndices()` for WGSL code snippet to calculate
Expand Down Expand Up @@ -192,9 +193,9 @@ export interface IndicesHelper {
readonly name: string;

/**
* whether the helper is for an input or an output.
* whether the helper is for an input, an output or an internal variable.
*/
readonly usage: 'input'|'output';
readonly usage: 'input'|'output'|'internal';

/**
* the rank of the input or output.
Expand Down Expand Up @@ -330,12 +331,12 @@ export const sumVector = (name: string, components: number) => {
* @param name - the name of the input or output.
* @param tensorType - the tensor type of the input or output.
* @param shapeOrRank - the tensor shape or the rank of the input or output.
* @param isInput - whether the helper is for an input or an output.
* @param usage - the usage of the indices helper.
* @param components - indicates the number of components of each element. 1 for scalar, 2 for vec2, 3 for vec3, 4 for
* vec4.
*/
const createIndicesHelper =
(name: string, tensorType: number, shapeOrRank: number|readonly number[], isInput: boolean,
(name: string, tensorType: number, shapeOrRank: number|readonly number[], usage: IndicesHelper['usage'],
components: 1|2|3|4): IndicesHelper => {
const useUniform = typeof shapeOrRank === 'number';
const rank = useUniform ? shapeOrRank : shapeOrRank.length;
Expand Down Expand Up @@ -612,7 +613,7 @@ const createIndicesHelper =
getByOffset,
getByIndices,
// isVec4,
usage: isInput ? 'input' : 'output',
usage,
name,
strides,
shape,
Expand All @@ -631,7 +632,7 @@ const createIndicesHelper =
*/
export const inputVariable =
(name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper =>
createIndicesHelper(name, type, shapeOrRank, true, components);
createIndicesHelper(name, type, shapeOrRank, 'input', components);

/**
* Create a IndicesHelper for an output.
Expand All @@ -644,7 +645,20 @@ export const inputVariable =
*/
export const outputVariable =
(name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper =>
createIndicesHelper(name, type, shapeOrRank, false, components);
createIndicesHelper(name, type, shapeOrRank, 'output', components);

/**
* Create a IndicesHelper for an internal variable.
*
* @param name - the name of the variable.
* @param type - the tensor type of the variable.
* @param shapeOrRank - the tensor shape or the rank of the variable.
* @param components - the number of components of the variable. available values are 1, 2, 3, 4. default is 1.
* @returns an IndicesHelper for the variable.
*/
export const internalVariable =
(name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper =>
createIndicesHelper(name, type, shapeOrRank, 'internal', components);

export type UniformsArrayType = Array<{name: string; type: string}>;

Expand Down Expand Up @@ -711,6 +725,14 @@ export interface ShaderHelper {
*/
registerUniforms(uniforms: UniformsArrayType): ShaderHelper;

/**
* A helper function to register multiple internal variables. Can be called multiple times to register multiple
* internal variables.
*
* @param variables - an array of IndicesHelper for the variables.
*/
registerInternalVariables(...variables: IndicesHelper[]): ShaderHelper;

/**
* A helper function to include a compatible helper for the purpose of generating specific WGSL code in front of the
* shader source. Can be called multiple times to include multiple helpers.
Expand Down Expand Up @@ -752,8 +774,7 @@ class ShaderHelperImpl implements ShaderHelper {
`;
}

private declareVariable(variable: IndicesHelper, bindingIndex: number): string {
this.variables.push(variable);
private appendVariableUniforms(variable: IndicesHelper): void {
if (variable.rank !== 0) {
if (variable.shape.startsWith('uniforms.')) {
this.uniforms.push({name: variable.shape.replace('uniforms.', ''), type: variable.type.indices});
Expand All @@ -762,6 +783,15 @@ class ShaderHelperImpl implements ShaderHelper {
this.uniforms.push({name: variable.strides.replace('uniforms.', ''), type: variable.type.indices});
}
}
}

private declareVariable(variable: IndicesHelper, bindingIndex: number): string {
if (variable.usage === 'internal') {
throw new Error('cannot use internal variable with declareVariable(). use registerInternalVariables() instead.');
}
this.variables.push(variable);
this.appendVariableUniforms(variable);

const access = variable.usage === 'input' ? 'read' : 'read_write';
const storageType = variable.type.storage;
return `@group(0) @binding(${bindingIndex}) var<storage, ${access}> ${variable.name}: array<${storageType}>;`;
Expand All @@ -771,6 +801,21 @@ class ShaderHelperImpl implements ShaderHelper {
return variables.map(v => this.declareVariable(v, this.variableIndex++)).join('\n');
}

private registerInternalVariable(variable: IndicesHelper): void {
if (variable.usage !== 'internal') {
throw new Error(
'cannot use input or output variable with registerInternalVariable(). use declareVariables() instead.');
}

this.internalVariables.push(variable);
this.appendVariableUniforms(variable);
}

registerInternalVariables(...variables: IndicesHelper[]): ShaderHelper {
variables.forEach(v => this.registerInternalVariable(v));
return this;
}

registerUniform(name: string, type: string): ShaderHelper {
this.uniforms.push({name, type});
return this;
Expand All @@ -786,6 +831,7 @@ class ShaderHelperImpl implements ShaderHelper {
return this;
}

private internalVariables: IndicesHelper[] = [];
private variables: IndicesHelper[] = [];
private helpers: Array<{impl: () => string}> = [];
private uniforms: UniformsArrayType = [];
Expand Down

0 comments on commit c3f6e78

Please sign in to comment.