From 3baa6f956637d0dfeb2ceeffabe90a9952485cc4 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 20 Nov 2023 15:55:24 -0800 Subject: [PATCH] [js/web] allow ShaderHelper to "use" helper --- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 33 ++++++++++++++++++++--- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 014d9d02f6f10..17fc4aa0a390f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -697,9 +697,27 @@ export interface ShaderHelper { /** * A helper function to register one uniform. Can be called multiple times to register multiple uniforms. + * + * @param name - the name of the uniform. + * @param type - the type of the uniform. */ registerUniform(name: string, type: string): ShaderHelper; - registerUniforms(nameToTypeMap: UniformsArrayType): ShaderHelper; + + /** + * A helper function to register multiple uniforms. Can be called multiple times to register multiple uniforms. + * + * @param uniforms - an array of uniforms. Each element of the array is an object with 2 properties: `name` and + * `type`. + */ + registerUniforms(uniforms: UniformsArrayType): ShaderHelper; + + /** + * A helper function to include a compatible helper for the purpose of generating specific WGSL code in front of the + * shader source. Can be called multiple times to include multiple helpers. + * + * @param helperLike - a helper-like object that has a function `impl` that returns a string of WGSL code. + */ + use(helperLike: {impl: () => string}): ShaderHelper; } class ShaderHelperImpl implements ShaderHelper { @@ -735,7 +753,7 @@ class ShaderHelperImpl implements ShaderHelper { } private declareVariable(variable: IndicesHelper, bindingIndex: number): string { - this.indicesHelpers.push(variable); + this.variables.push(variable); if (variable.rank !== 0) { if (variable.shape.startsWith('uniforms.')) { this.uniforms.push({name: variable.shape.replace('uniforms.', ''), type: variable.type.indices}); @@ -763,7 +781,13 @@ class ShaderHelperImpl implements ShaderHelper { return this; } - private indicesHelpers: IndicesHelper[] = []; + use(helperLike: {impl: () => string}): ShaderHelper { + this.helpers.push(helperLike as IndicesHelper); + return this; + } + + private variables: IndicesHelper[] = []; + private helpers: Array<{impl: () => string}> = []; private uniforms: UniformsArrayType = []; private uniformDeclaration(): string { if (this.uniforms.length === 0) { @@ -785,7 +809,8 @@ class ShaderHelperImpl implements ShaderHelper { * Get additional implementation that needs to be added to the shader source. */ get additionalImplementations(): string { - return this.uniformDeclaration() + this.indicesHelpers.map(i => i.impl()).join('\n'); + return this.uniformDeclaration() + this.variables.map(i => i.impl()).join('\n') + + this.helpers.map(i => i.impl()).join('\n'); } }