Skip to content

Commit

Permalink
[js/web] allow ShaderHelper to "use" helper
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Nov 20, 2023
1 parent cc54202 commit 3baa6f9
Showing 1 changed file with 29 additions and 4 deletions.
33 changes: 29 additions & 4 deletions js/web/lib/wasm/jsep/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -697,9 +697,27 @@ export interface ShaderHelper {

/**
* A helper function to register one uniform. Can be called multiple times to register multiple uniforms.
*
* @param name - the name of the uniform.
* @param type - the type of the uniform.
*/
registerUniform(name: string, type: string): ShaderHelper;
registerUniforms(nameToTypeMap: UniformsArrayType): ShaderHelper;

/**
* A helper function to register multiple uniforms. Can be called multiple times to register multiple uniforms.
*
* @param uniforms - an array of uniforms. Each element of the array is an object with 2 properties: `name` and
* `type`.
*/
registerUniforms(uniforms: UniformsArrayType): ShaderHelper;

/**
* A helper function to include a compatible helper for the purpose of generating specific WGSL code in front of the
* shader source. Can be called multiple times to include multiple helpers.
*
* @param helperLike - a helper-like object that has a function `impl` that returns a string of WGSL code.
*/
use(helperLike: {impl: () => string}): ShaderHelper;
}

class ShaderHelperImpl implements ShaderHelper {
Expand Down Expand Up @@ -735,7 +753,7 @@ class ShaderHelperImpl implements ShaderHelper {
}

private declareVariable(variable: IndicesHelper, bindingIndex: number): string {
this.indicesHelpers.push(variable);
this.variables.push(variable);
if (variable.rank !== 0) {
if (variable.shape.startsWith('uniforms.')) {
this.uniforms.push({name: variable.shape.replace('uniforms.', ''), type: variable.type.indices});
Expand Down Expand Up @@ -763,7 +781,13 @@ class ShaderHelperImpl implements ShaderHelper {
return this;
}

private indicesHelpers: IndicesHelper[] = [];
use(helperLike: {impl: () => string}): ShaderHelper {
this.helpers.push(helperLike as IndicesHelper);
return this;
}

private variables: IndicesHelper[] = [];
private helpers: Array<{impl: () => string}> = [];
private uniforms: UniformsArrayType = [];
private uniformDeclaration(): string {
if (this.uniforms.length === 0) {
Expand All @@ -785,7 +809,8 @@ class ShaderHelperImpl implements ShaderHelper {
* Get additional implementation that needs to be added to the shader source.
*/
get additionalImplementations(): string {
return this.uniformDeclaration() + this.indicesHelpers.map(i => i.impl()).join('\n');
return this.uniformDeclaration() + this.variables.map(i => i.impl()).join('\n') +
this.helpers.map(i => i.impl()).join('\n');
}
}

Expand Down

0 comments on commit 3baa6f9

Please sign in to comment.