From 841f7ed3e0c393b22b1631c090c61b20fc62f876 Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Wed, 22 Nov 2023 14:14:24 -0800 Subject: [PATCH] [[JS/Web]Added uniform to Expand op. (#18558) ### Description Added Uniforms to Expand operator kernel ### Motivation and Context Improve performance --- js/web/lib/wasm/jsep/webgpu/ops/expand.ts | 28 +++++++++++++++------- js/web/test/data/ops/expand.jsonc | 29 +++++++++++++++++++++++ 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts index 5680af4787b6a..d998013352d77 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts @@ -3,9 +3,9 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 2) { @@ -47,14 +47,18 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => const outputSize = ShapeUtil.size(outputShape); const dataType = inputs[0].dataType; - const input = inputVariable('input', dataType, inputShape); - const output = outputVariable('output', dataType, outputShape); + const enableInputShapeUniform = enableShapesUniforms(inputShape.length); + const inputShapeOrRank = enableInputShapeUniform ? inputShape.length : inputShape; + const input = inputVariable('input', dataType, inputShapeOrRank); + const enableOutputShapeUniform = enableShapesUniforms(outputShape.length); + const outputShapeOrRank = enableOutputShapeUniform ? outputShape.length : outputShape; + const output = outputVariable('output', dataType, outputShapeOrRank); const getShaderSource = (shaderHelper: ShaderHelper) => ` const inputShape = ${input.indices(...inputShape)}; - ${shaderHelper.declareVariables(input, output)} + ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(input, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')} let outputIndices = ${output.offsetToIndices('global_idx')}; var inputIndices: ${input.type.indices}; for (var i = 0; i < ${inputShape.length}; i++) { @@ -68,13 +72,21 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => } ${output.setByOffset('global_idx', input.getByIndices('inputIndices'))} }`; + const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}]; + if (enableInputShapeUniform) { + programUniforms.push(...createTensorShapeVariables(inputShape)); + } + if (enableOutputShapeUniform) { + programUniforms.push(...createTensorShapeVariables(outputShape)); + } return { name: 'Expand', - shaderCache: {hint: `${outputShape}`}, + shaderCache: {hint: `${outputShape}`, inputDependencies: [enableInputShapeUniform ? 'rank' : 'dims']}, getShaderSource, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }) }; }; diff --git a/js/web/test/data/ops/expand.jsonc b/js/web/test/data/ops/expand.jsonc index 460122b4e085c..35888e2fc3709 100644 --- a/js/web/test/data/ops/expand.jsonc +++ b/js/web/test/data/ops/expand.jsonc @@ -85,5 +85,34 @@ ] } ] + }, + { + "name": "Expand 5D - float32", + "operator": "Expand", + "attributes": [], + "cases": [ + { + "name": "Expand 5 - float32", + "inputs": [ + { + "data": [1], + "dims": [1, 1, 1, 1, 1], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 6], + "dims": [5], + "type": "int64" + } + ], + "outputs": [ + { + "data": [1, 1, 1, 1, 1, 1], + "dims": [1, 1, 1, 1, 6], + "type": "float32" + } + ] + } + ] } ]