Skip to content

Commit

Permalink
[JS/Web] Added Uniforms in GatherElements. (#18670)
Browse files Browse the repository at this point in the history
### Description
Use uniforms in GatherElements and clean up the shader code.



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
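Improve performance by passing shape-dependent values (output size, axis, axis dimension limit) to the shader as uniforms instead of hard-coding them into the generated shader source.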
  • Loading branch information
satyajandhyala authored Dec 5, 2023
1 parent f949e05 commit 7081600
Showing 1 changed file with 26 additions and 32 deletions.
58 changes: 26 additions & 32 deletions js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';

import {inputVariable, outputVariable, ShaderHelper} from './common';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';

export interface GatherElementsAttributes extends AttributeWithCacheKey {
axis: number;
Expand All @@ -32,65 +32,59 @@ const createGatherElementsProgramInfo =
const inputShape = inputs[0].dims;
const inputOutputDataType = inputs[0].dataType;
const inputRank = inputShape.length;
const inputStrides = ShapeUtil.computeStrides(inputShape);
const inputSize = ShapeUtil.size(inputShape);

const indicesShape = inputs[1].dims;
const indicesDataType = inputs[1].dataType;
const indicesSize = ShapeUtil.size(indicesShape);

const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank);
const axisDimLimit = inputShape[axis];

const outputShape = indicesShape.slice(0);
const outputSize = ShapeUtil.size(outputShape);

const input = inputVariable('input', inputOutputDataType, inputShape);
const indices = inputVariable('indices', indicesDataType, [indicesSize]);
const output = outputVariable('output', inputOutputDataType, outputShape);
const input = inputVariable('input', inputOutputDataType, inputRank);
const indices = inputVariable('indicesInput', indicesDataType, indicesShape.length);
const output = outputVariable('output', inputOutputDataType, outputShape.length);


const programUniforms: ProgramUniform[] =
[{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}];
programUniforms.push(...createTensorShapeVariables(inputShape));
programUniforms.push(...createTensorShapeVariables(indicesShape));
programUniforms.push(...createTensorShapeVariables(outputShape));
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];

// int64 indices would be treated as little endian i32 with assumption they fall in i32 limits
// That assumption is safe as it's not possible to allocate >2gb buffer for input tensor
// Input data will be treated as u32 or two u32 for 8-byte tensors
const getShaderSource = (shaderHelper: ShaderHelper) => `
const inputStrides = array<u32, ${inputStrides.length}>(${inputStrides.map(i => `${i}u`).join(',')});
${shaderHelper.declareVariables(input, indices, output)}
${
shaderHelper.registerUniform('outputSize', 'u32')
.registerUniform('axisDimLimit', 'i32')
.registerUniform('axis', 'u32')
.declareVariables(input, indices, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
let outputIndices = ${output.offsetToIndices('global_idx')};
var idx = ${indices.getByOffset('global_idx')};
if (idx < 0) {
idx = idx + ${axisDimLimit};
}
var srcOffset = u32(0);
for (var i = 0; i < ${inputShape.length}; i++) {
if (i == ${axis}) {
srcOffset += u32(idx) * inputStrides[i];
} else {
srcOffset += ${output.indicesGet('outputIndices', 'i')} * inputStrides[i];
}
}
// Should never hit this with valid values in indices
// This is a guard against malicious data in the indices input
if (srcOffset < 0 || srcOffset >= ${inputSize}) {
return;
idx = idx + uniforms.axisDimLimit;
}
var inputIndices = ${input.type.indices}(outputIndices);
${input.indicesSet('inputIndices', 'uniforms.axis', 'u32(idx)')};
let value = ${input.getByIndices('inputIndices')};
output[global_idx] = input[srcOffset];
${output.setByOffset('global_idx', 'value')};
}`;

return {
name: 'GatherElements',
shaderCache: {hint: attributes.cacheKey},
shaderCache: {inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms
}),
getShaderSource,
};
Expand Down

0 comments on commit 7081600

Please sign in to comment.