diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index fdcd64abfe4e7..5d6d6debadb9a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -4,9 +4,9 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; export interface GatherAttributes extends AttributeWithCacheKey { axis: number; @@ -31,20 +31,44 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath const axisDimLimit = inputShape[axis]; const outputSize = ShapeUtil.size(outputShape); - const data = inputVariable('data', inputs[0].dataType, inputs[0].dims); - const indices = inputVariable('inputIndices', inputs[1].dataType, inputs[1].dims); - const output = outputVariable('output', inputs[0].dataType, outputShape); + const enableInputShapesUniforms = enableShapesUniforms(inputs[0].dims.length); + const inputShapeOrRank = enableInputShapesUniforms ? inputs[0].dims.length : inputs[0].dims; + const enableIndicesShapesUniforms = enableShapesUniforms(inputs[1].dims.length); + const indicesShapeOrRank = enableIndicesShapesUniforms ? inputs[1].dims.length : inputs[1].dims; + const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); + const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; + + const data = inputVariable('data', inputs[0].dataType, inputShapeOrRank); + const indices = inputVariable('inputIndices', inputs[1].dataType, indicesShapeOrRank); + const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank); + + const programUniforms: ProgramUniform[] = + [{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}]; + if (enableInputShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); + } + if (enableIndicesShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); + } + if (enableOutputShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(outputShape)); + } + + const inputDependencies: ProgramInputTensorInfoDependency[] = []; + inputDependencies.push(enableInputShapesUniforms ? 'rank' : 'dims'); + inputDependencies.push(enableIndicesShapesUniforms ? 'rank' : 'dims'); + const calcDataIndices = (): string => { const indicesRank = indicesShape.length; let calcStr = `var indicesIndices = ${indices.type.indices}(0);`; for (let i = 0; i < indicesRank; i++) { calcStr += `${indicesRank > 1 ? `indicesIndices[${i}]` : 'indicesIndices'} = ${ - outputShape.length > 1 ? `outputIndices[${axis + i}]` : 'outputIndices'};`; + outputShape.length > 1 ? `outputIndices[uniforms.axis + ${i}]` : 'outputIndices'};`; } calcStr += ` var idx = ${indices.getByIndices('indicesIndices')}; if (idx < 0) { - idx = idx + ${axisDimLimit}; + idx = idx + uniforms.axisDimLimit; } var dataIndices = ${data.type.indices}(0); `; @@ -62,9 +86,13 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath }; const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(data, indices, output)} + ${ + shaderHelper.registerUniform('outputSize', 'u32') + .registerUniform('axisDimLimit', 'i32') + .registerUniform('axis', 'u32') + .declareVariables(data, indices, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} let outputIndices = ${output.offsetToIndices('global_idx')}; ${calcDataIndices()}; let value = ${data.getByIndices('dataIndices')}; @@ -72,12 +100,13 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath }`; return { name: 'Gather', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: {hint: attributes.cacheKey, inputDependencies}, getRunData: () => ({ outputs: [ {dims: outputShape, dataType: inputs[0].dataType}, ], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }), getShaderSource, };