From 70816001ccae305de24e27ab2219a8a17e1ca036 Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Tue, 5 Dec 2023 09:19:53 -0800 Subject: [PATCH] [JS/Web] AddedUniforms in GatherElements. (#18670) ### Description Use Uniforms in GatherElements and clean-up ### Motivation and Context Improve performance --- .../wasm/jsep/webgpu/ops/gather-elements.ts | 58 +++++++++---------- 1 file changed, 26 insertions(+), 32 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts index 9924a50e2ae6f..a945954adcaa4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts @@ -4,9 +4,9 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; export interface GatherElementsAttributes extends AttributeWithCacheKey { axis: number; @@ -32,65 +32,59 @@ const createGatherElementsProgramInfo = const inputShape = inputs[0].dims; const inputOutputDataType = inputs[0].dataType; const inputRank = inputShape.length; - const inputStrides = ShapeUtil.computeStrides(inputShape); - const inputSize = ShapeUtil.size(inputShape); const indicesShape = inputs[1].dims; const indicesDataType = inputs[1].dataType; - const indicesSize = ShapeUtil.size(indicesShape); - const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank); const axisDimLimit = inputShape[axis]; const outputShape = indicesShape.slice(0); const outputSize = ShapeUtil.size(outputShape); - const input = inputVariable('input', inputOutputDataType, inputShape); - const indices = inputVariable('indices', indicesDataType, [indicesSize]); - const output = outputVariable('output', inputOutputDataType, outputShape); + const input = inputVariable('input', inputOutputDataType, inputRank); + const indices = inputVariable('indicesInput', indicesDataType, indicesShape.length); + const output = outputVariable('output', inputOutputDataType, outputShape.length); + + const programUniforms: ProgramUniform[] = + [{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}]; + programUniforms.push(...createTensorShapeVariables(inputShape)); + programUniforms.push(...createTensorShapeVariables(indicesShape)); + programUniforms.push(...createTensorShapeVariables(outputShape)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; // int64 indices would be treated as little endian i32 with assumption they fall in i32 limits // That assumption is safe as it's not possible to allocate >2gb buffer for input tensor // Input data will be treated as u32 or two u32 for 8-byte tensors const getShaderSource = (shaderHelper: ShaderHelper) => ` - const inputStrides = array(${inputStrides.map(i => `${i}u`).join(',')}); - ${shaderHelper.declareVariables(input, indices, output)} + ${ + shaderHelper.registerUniform('outputSize', 'u32') + .registerUniform('axisDimLimit', 'i32') + .registerUniform('axis', 'u32') + .declareVariables(input, indices, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} let outputIndices = ${output.offsetToIndices('global_idx')}; var idx = ${indices.getByOffset('global_idx')}; if (idx < 0) { - idx = idx + ${axisDimLimit}; - } - - var srcOffset = u32(0); - - for (var i = 0; i < ${inputShape.length}; i++) { - if (i == ${axis}) { - srcOffset += u32(idx) * inputStrides[i]; - } else { - srcOffset += ${output.indicesGet('outputIndices', 'i')} * inputStrides[i]; - } - } - - // Should never hit this with valid values in indices - // This is a guard against malicious data in the indices input - if (srcOffset < 0 || srcOffset >= ${inputSize}) { - return; + idx = idx + uniforms.axisDimLimit; } + var inputIndices = ${input.type.indices}(outputIndices); + ${input.indicesSet('inputIndices', 'uniforms.axis', 'u32(idx)')}; + let value = ${input.getByIndices('inputIndices')}; - output[global_idx] = input[srcOffset]; + ${output.setByOffset('global_idx', 'value')}; }`; return { name: 'GatherElements', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: {inputDependencies}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }), getShaderSource, };