Skip to content

Commit

Permalink
[js/webgpu] Support uniforms for gather
Browse files: browse the repository at this point in the history
  • Loading branch information
axinging committed Nov 7, 2023
1 parent 630c877 commit f09349d
Showing 1 changed file with 40 additions and 11 deletions.
51 changes: 40 additions & 11 deletions js/web/lib/wasm/jsep/webgpu/ops/gather.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';

import {inputVariable, outputVariable, ShaderHelper} from './common';
import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common';

export interface GatherAttributes extends AttributeWithCacheKey {
axis: number;
Expand All @@ -31,20 +31,44 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath
const axisDimLimit = inputShape[axis];
const outputSize = ShapeUtil.size(outputShape);

const data = inputVariable('data', inputs[0].dataType, inputs[0].dims);
const indices = inputVariable('inputIndices', inputs[1].dataType, inputs[1].dims);
const output = outputVariable('output', inputs[0].dataType, outputShape);
const enableInputShapesUniforms = enableShapesUniforms(inputs[0].dims.length);
const inputShapeOrRank = enableInputShapesUniforms ? inputs[0].dims.length : inputs[0].dims;
const enableIndicesShapesUniforms = enableShapesUniforms(inputs[1].dims.length);
const indicesShapeOrRank = enableIndicesShapesUniforms ? inputs[1].dims.length : inputs[1].dims;
const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length);
const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape;

const data = inputVariable('data', inputs[0].dataType, inputShapeOrRank);
const indices = inputVariable('inputIndices', inputs[1].dataType, indicesShapeOrRank);
const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank);

const programUniforms: ProgramUniform[] =
[{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}];
if (enableInputShapesUniforms) {
programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
}
if (enableIndicesShapesUniforms) {
programUniforms.push(...createTensorShapeVariables(inputs[1].dims));
}
if (enableOutputShapesUniforms) {
programUniforms.push(...createTensorShapeVariables(outputShape));
}

const inputDependencies: ProgramInputTensorInfoDependency[] = [];
inputDependencies.push(enableInputShapesUniforms ? 'rank' : 'dims');
inputDependencies.push(enableIndicesShapesUniforms ? 'rank' : 'dims');

const calcDataIndices = (): string => {
const indicesRank = indicesShape.length;
let calcStr = `var indicesIndices = ${indices.type.indices}(0);`;
for (let i = 0; i < indicesRank; i++) {
calcStr += `${indicesRank > 1 ? `indicesIndices[${i}]` : 'indicesIndices'} = ${
outputShape.length > 1 ? `outputIndices[${axis + i}]` : 'outputIndices'};`;
outputShape.length > 1 ? `outputIndices[uniforms.axis + ${i}]` : 'outputIndices'};`;
}
calcStr += `
var idx = ${indices.getByIndices('indicesIndices')};
if (idx < 0) {
idx = idx + ${axisDimLimit};
idx = idx + uniforms.axisDimLimit;
}
var dataIndices = ${data.type.indices}(0);
`;
Expand All @@ -62,22 +86,27 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath
};

const getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.declareVariables(data, indices, output)}
${
shaderHelper.registerUniform('outputSize', 'u32')
.registerUniform('axisDimLimit', 'i32')
.registerUniform('axis', 'u32')
.declareVariables(data, indices, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
let outputIndices = ${output.offsetToIndices('global_idx')};
${calcDataIndices()};
let value = ${data.getByIndices('dataIndices')};
${output.setByOffset('global_idx', 'value')};
}`;
return {
name: 'Gather',
shaderCache: {hint: attributes.cacheKey},
shaderCache: {hint: attributes.cacheKey, inputDependencies},
getRunData: () => ({
outputs: [
{dims: outputShape, dataType: inputs[0].dataType},
],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms
}),
getShaderSource,
};
Expand Down

0 comments on commit f09349d

Please sign in to comment.