Skip to content

Commit

Permalink
[JS/Web] Added Unifroms support to unary ops. (microsoft#18223)
Browse files Browse the repository at this point in the history
### Description
Added uniforms support to unary ops.


### Motivation and Context
Improve performance
  • Loading branch information
satyajandhyala authored and kleiti committed Mar 22, 2024
1 parent 22b2494 commit bf47351
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ const createElementwiseProgramShader =
const output = outputVariable('outputData', outputDataType, [vecSize], 4);

return `
${shaderHelper.declareVariables(input, output)}
${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(input, output)}
${additionalImplementation ?? ''}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')}
let a = ${input.getByOffset('global_idx')};
${output.setByOffset('global_idx', expression)}
Expand All @@ -45,13 +45,16 @@ const createElementwiseProgramInfo =
(input: TensorView, name: string, funcCall: ElementwiseFunctionCall, additionalImplementation?: string,
cacheKey?: string, outputDataType: number = input.dataType): ProgramInfo => ({
name,
shaderCache: {hint: cacheKey},
shaderCache: {hint: cacheKey, inputDependencies: ['type']},
getShaderSource: shaderHelper => createElementwiseProgramShader(
shaderHelper, ShapeUtil.size(input.dims), input.dataType, outputDataType, funcCall, additionalImplementation),
getRunData: (inputTensors) => ({
outputs: [{dims: input.dims, dataType: outputDataType}],
dispatchGroup:
{x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */)}
{x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */)},
programUniforms: [
{type: 'uint32', data: Math.ceil(ShapeUtil.size(input.dims) / 4)},
],
})
});

Expand Down

0 comments on commit bf47351

Please sign in to comment.