Skip to content

Commit

Permalink
[[JS/Web]Added uniform to Expand op. (#18558)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
Added Uniforms to Expand operator kernel


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Improve performance
  • Loading branch information
satyajandhyala authored Nov 22, 2023
1 parent 1c555c5 commit 841f7ed
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 8 deletions.
28 changes: 20 additions & 8 deletions js/web/lib/wasm/jsep/webgpu/ops/expand.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {ComputeContext, ProgramInfo} from '../types';
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';

import {inputVariable, outputVariable, ShaderHelper} from './common';
import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common';

const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs || inputs.length !== 2) {
Expand Down Expand Up @@ -47,14 +47,18 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
const outputSize = ShapeUtil.size(outputShape);

const dataType = inputs[0].dataType;
const input = inputVariable('input', dataType, inputShape);
const output = outputVariable('output', dataType, outputShape);
const enableInputShapeUniform = enableShapesUniforms(inputShape.length);
const inputShapeOrRank = enableInputShapeUniform ? inputShape.length : inputShape;
const input = inputVariable('input', dataType, inputShapeOrRank);
const enableOutputShapeUniform = enableShapesUniforms(outputShape.length);
const outputShapeOrRank = enableOutputShapeUniform ? outputShape.length : outputShape;
const output = outputVariable('output', dataType, outputShapeOrRank);

const getShaderSource = (shaderHelper: ShaderHelper) => `
const inputShape = ${input.indices(...inputShape)};
${shaderHelper.declareVariables(input, output)}
${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(input, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')}
let outputIndices = ${output.offsetToIndices('global_idx')};
var inputIndices: ${input.type.indices};
for (var i = 0; i < ${inputShape.length}; i++) {
Expand All @@ -68,13 +72,21 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
}
${output.setByOffset('global_idx', input.getByIndices('inputIndices'))}
}`;
const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}];
if (enableInputShapeUniform) {
programUniforms.push(...createTensorShapeVariables(inputShape));
}
if (enableOutputShapeUniform) {
programUniforms.push(...createTensorShapeVariables(outputShape));
}
return {
name: 'Expand',
shaderCache: {hint: `${outputShape}`},
shaderCache: {hint: `${outputShape}`, inputDependencies: [enableInputShapeUniform ? 'rank' : 'dims']},
getShaderSource,
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms
})
};
};
Expand Down
29 changes: 29 additions & 0 deletions js/web/test/data/ops/expand.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,34 @@
]
}
]
},
{
"name": "Expand 5D - float32",
"operator": "Expand",
"attributes": [],
"cases": [
{
"name": "Expand 5 - float32",
"inputs": [
{
"data": [1],
"dims": [1, 1, 1, 1, 1],
"type": "float32"
},
{
"data": [1, 1, 1, 1, 6],
"dims": [5],
"type": "int64"
}
],
"outputs": [
{
"data": [1, 1, 1, 1, 1, 1],
"dims": [1, 1, 1, 1, 6],
"type": "float32"
}
]
}
]
}
]

0 comments on commit 841f7ed

Please sign in to comment.