From c3f6e78efc4109f42cd5b7ed1d56a920bfefb336 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 21 Nov 2023 14:50:30 -0800 Subject: [PATCH] update to apply revise as comments --- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 66 +++++++++++++++++++---- 1 file changed, 56 insertions(+), 10 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 5ef8fa5bcbf37..657ac09c3dbc5 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -58,10 +58,11 @@ interface IndicesHelperTypes { * create an instance of an indices helper: * - `inputVariable()`: create an indices helper instance for an input. * - `outputVariable()`: create an indices helper instance for an output. + * - `internalVariable()`: create an indices helper instance for an internal variable. * * An indices helper instance contains helper functions for the following operations: * - access readonly basic information, including: `name`(the name of the input or output), `usage`(whether it's an - * input or an output) and `shape`(the passed in shape). + * input, an output or an internal variable) and `shape`(the passed in shape). * - `type`: access readonly type information, including: `indices`(the type of indices), `value`(the type of value at * runtime), `storage`(the type of value at storage) and `tensor`(the tensor type as represented in TensorView). * - generate WGSL code for getting indices from offset. Use `offsetToIndices()` for WGSL code snippet to calculate @@ -192,9 +193,9 @@ export interface IndicesHelper { readonly name: string; /** - * whether the helper is for an input or an output. + * whether the helper is for an input, an output or an internal variable. */ - readonly usage: 'input'|'output'; + readonly usage: 'input'|'output'|'internal'; /** * the rank of the input or output. @@ -330,12 +331,12 @@ export const sumVector = (name: string, components: number) => { * @param name - the name of the input or output. * @param tensorType - the tensor type of the input or output. * @param shapeOrRank - the tensor shape or the rank of the input or output. - * @param isInput - whether the helper is for an input or an output. + * @param usage - the usage of the indices helper. * @param components - indicates the number of components of each element. 1 for scalar, 2 for vec2, 3 for vec3, 4 for * vec4. */ const createIndicesHelper = - (name: string, tensorType: number, shapeOrRank: number|readonly number[], isInput: boolean, + (name: string, tensorType: number, shapeOrRank: number|readonly number[], usage: IndicesHelper['usage'], components: 1|2|3|4): IndicesHelper => { const useUniform = typeof shapeOrRank === 'number'; const rank = useUniform ? shapeOrRank : shapeOrRank.length; @@ -612,7 +613,7 @@ const createIndicesHelper = getByOffset, getByIndices, // isVec4, - usage: isInput ? 'input' : 'output', + usage, name, strides, shape, @@ -631,7 +632,7 @@ const createIndicesHelper = */ export const inputVariable = (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => - createIndicesHelper(name, type, shapeOrRank, true, components); + createIndicesHelper(name, type, shapeOrRank, 'input', components); /** * Create a IndicesHelper for an output. @@ -644,7 +645,20 @@ export const inputVariable = */ export const outputVariable = (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => - createIndicesHelper(name, type, shapeOrRank, false, components); + createIndicesHelper(name, type, shapeOrRank, 'output', components); + +/** + * Create a IndicesHelper for an internal variable. + * + * @param name - the name of the variable. + * @param type - the tensor type of the variable. + * @param shapeOrRank - the tensor shape or the rank of the variable. + * @param components - the number of components of the variable. available values are 1, 2, 3, 4. default is 1. + * @returns an IndicesHelper for the variable. + */ +export const internalVariable = + (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => + createIndicesHelper(name, type, shapeOrRank, 'internal', components); export type UniformsArrayType = Array<{name: string; type: string}>; @@ -711,6 +725,14 @@ export interface ShaderHelper { */ registerUniforms(uniforms: UniformsArrayType): ShaderHelper; + /** + * A helper function to register multiple internal variables. Can be called multiple times to register multiple + * internal variables. + * + * @param variables - an array of IndicesHelper for the variables. + */ + registerInternalVariables(...variables: IndicesHelper[]): ShaderHelper; + /** * A helper function to include a compatible helper for the purpose of generating specific WGSL code in front of the * shader source. Can be called multiple times to include multiple helpers. @@ -752,8 +774,7 @@ class ShaderHelperImpl implements ShaderHelper { `; } - private declareVariable(variable: IndicesHelper, bindingIndex: number): string { - this.variables.push(variable); + private appendVariableUniforms(variable: IndicesHelper): void { if (variable.rank !== 0) { if (variable.shape.startsWith('uniforms.')) { this.uniforms.push({name: variable.shape.replace('uniforms.', ''), type: variable.type.indices}); @@ -762,6 +783,15 @@ class ShaderHelperImpl implements ShaderHelper { this.uniforms.push({name: variable.strides.replace('uniforms.', ''), type: variable.type.indices}); } } + } + + private declareVariable(variable: IndicesHelper, bindingIndex: number): string { + if (variable.usage === 'internal') { + throw new Error('cannot use internal variable with declareVariable(). use registerInternalVariables() instead.'); + } + this.variables.push(variable); + this.appendVariableUniforms(variable); + const access = variable.usage === 'input' ? 'read' : 'read_write'; const storageType = variable.type.storage; return `@group(0) @binding(${bindingIndex}) var ${variable.name}: array<${storageType}>;`; @@ -771,6 +801,21 @@ class ShaderHelperImpl implements ShaderHelper { return variables.map(v => this.declareVariable(v, this.variableIndex++)).join('\n'); } + private registerInternalVariable(variable: IndicesHelper): void { + if (variable.usage !== 'internal') { + throw new Error( + 'cannot use input or output variable with registerInternalVariable(). use declareVariables() instead.'); + } + + this.internalVariables.push(variable); + this.appendVariableUniforms(variable); + } + + registerInternalVariables(...variables: IndicesHelper[]): ShaderHelper { + variables.forEach(v => this.registerInternalVariable(v)); + return this; + } + registerUniform(name: string, type: string): ShaderHelper { this.uniforms.push({name, type}); return this; @@ -786,6 +831,7 @@ class ShaderHelperImpl implements ShaderHelper { return this; } + private internalVariables: IndicesHelper[] = []; private variables: IndicesHelper[] = []; private helpers: Array<{impl: () => string}> = []; private uniforms: UniformsArrayType = [];