From 5302cff6d4a14ee9def4e692bd7d874ead74a11f Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Fri, 26 Jan 2024 10:39:44 +0800 Subject: [PATCH] [js/webgpu] Remove enableShapesUniforms --- .../ops/3rd-party/matmul_packed_webgpu.ts | 12 +++--- js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts | 4 +- js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts | 37 +++++++----------- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 3 -- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 26 ++++--------- js/web/lib/wasm/jsep/webgpu/ops/einsum.ts | 31 +++++---------- js/web/lib/wasm/jsep/webgpu/ops/expand.ts | 25 ++++-------- js/web/lib/wasm/jsep/webgpu/ops/gather.ts | 39 ++++++------------- js/web/lib/wasm/jsep/webgpu/ops/transpose.ts | 25 +++++------- 9 files changed, 68 insertions(+), 134 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index ee71110245252..5881c055ef135 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -443,9 +443,9 @@ export const createMatmulProgramInfo = const components = isVec4 ? 4 : 1; const aShapeTemp = [...outerDimsA, dimAOuter, dimInner / components]; - const aShapeOrRank = aShapeTemp.length; + const aRank = aShapeTemp.length; const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components]; - const bShapeOrRank = bShapeTemp.length; + const bRank = bShapeTemp.length; const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; const programUniforms: ProgramUniform[] = [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; @@ -467,12 +467,12 @@ export const createMatmulProgramInfo = programUniforms.push(...createTensorShapeVariables(outputShapeTemp)); const getShaderSource = (shaderHelper: ShaderHelper) => { - const batchShapeOrRank = outerDims.length; - const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1); + const batchRank = outerDims.length; + const batchDims = internalVariable('batchDims', inputs[0].dataType, batchRank, 1); const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components); - const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components); + const A = inputVariable('a', inputs[0].dataType, aRank, components); + const B = inputVariable('b', inputs[1].dataType, bRank, components); const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components); const inputVariables = [A, B]; if (hasBias) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts index 00a6ca75b34fa..159b971636765 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts @@ -8,7 +8,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common'; export interface BatchNormAttributes extends AttributeWithCacheKey { readonly epsilon: number; @@ -61,7 +61,7 @@ const createBatchNormInferenceProgramInfo = const cComponents = format === 'NHWC' && yShape.length > 1 ? components : 1; const outputSize = ShapeUtil.size(yShape) / components; // Only support uniforms for opset version >= 9 (spatial = true). - const useShapesUniforms = enableShapesUniforms(yShape.length) && spatial; + const useShapesUniforms = spatial; const shapeOrRank = useShapesUniforms ? yShape.length : yShape; const x = inputVariable('x', inputs[0].dataType, inputs[0].dims, components); const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims, cComponents); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index c033c0ba05356..8e144a36dc1b0 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view'; import {BroadcastUtil, ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; type BuiltinFunctionName = string; type BinaryCustomExpression = (expressionA: string, expressionB: string) => string; @@ -18,8 +18,7 @@ type BinaryFunctionCall = BuiltinFunctionName|BinaryCustomExpression|{ const createBinaryOpProgramShader = (shaderHelper: ShaderHelper, dimsA: readonly number[], dimsB: readonly number[], dimsOutput: readonly number[], vectorize: boolean, doBroadcast: boolean, sharedDimensionDivisibleBy4: boolean, funcCall: BinaryFunctionCall, - typeA: number, typeB: number, typeOutput: number, useShapesUniforms: boolean, - additionalImplementation?: string) => { + typeA: number, typeB: number, typeOutput: number, additionalImplementation?: string) => { let expressionScalar: BinaryCustomExpression; let expressionVector: BinaryCustomExpression; if (typeof funcCall === 'string') { @@ -31,12 +30,9 @@ const createBinaryOpProgramShader = expressionVector = funcCall.vector; } - const inputAShapeOrRank = useShapesUniforms ? dimsA.length : dimsA; - const inputBShapeOrRank = useShapesUniforms ? dimsB.length : dimsB; - const outputShapeOrRank = useShapesUniforms ? dimsOutput.length : dimsOutput; - const output = outputVariable('outputData', typeOutput, outputShapeOrRank, 4); - const a = inputVariable('aData', typeA, inputAShapeOrRank, 4); - const b = inputVariable('bData', typeB, inputBShapeOrRank, 4); + const output = outputVariable('outputData', typeOutput, dimsOutput.length, 4); + const a = inputVariable('aData', typeA, dimsA.length, 4); + const b = inputVariable('bData', typeB, dimsB.length, 4); let assignment: string; if (vectorize) { @@ -169,30 +165,25 @@ const createBinaryOpProgramInfo = vectorize = true; } cacheKeyAux.push(vectorize); - const useShapesUniforms = enableShapesUniforms(a.dims.length) && enableShapesUniforms(b.dims.length) && - enableShapesUniforms(outputShape.length); + return { name, shaderCache: { hint: cacheKey + cacheKeyAux.map((x) => x.toString()).join('_'), - inputDependencies: useShapesUniforms ? ['rank', 'rank'] : ['dims', 'dims'], + inputDependencies: ['rank', 'rank'], }, getShaderSource: (shaderHelper) => createBinaryOpProgramShader( shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, sharedDimensionDivisibleBy4, funcCall, - a.dataType, b.dataType, outputDataType, useShapesUniforms, additionalImplementation), + a.dataType, b.dataType, outputDataType, additionalImplementation), getRunData: () => ({ outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}, - programUniforms: useShapesUniforms ? - [ - {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, - ...createTensorShapeVariables(a.dims), - ...createTensorShapeVariables(b.dims), - ...createTensorShapeVariables(outputShape), - ] : - [ - {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, - ], + programUniforms: [ + {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, + ...createTensorShapeVariables(a.dims), + ...createTensorShapeVariables(b.dims), + ...createTensorShapeVariables(outputShape), + ], }), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index bc3265be955f0..2ed4fdcca0071 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -908,6 +908,3 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly } return dims; }; - -// TODO: remove this when all related uses have been removed. -export const enableShapesUniforms = (_rank: number): boolean => true; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 43cc4a4c080bd..daa326b1a34e2 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -6,7 +6,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; export interface ConcatAttributes extends AttributeWithCacheKey { readonly axis: number; @@ -94,32 +94,22 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P let previousSum = 0; const inputDependencies: ProgramInputTensorInfoDependency[] = []; - const inputShapeOrRanks = []; - const enableInputShapesUniforms = []; + const inputRanks = []; const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}]; for (let i = 0; i < inputs.length; ++i) { previousSum += inputs[i].dims[adjustedAxis]; sizeInConcatAxis[i] = previousSum; - enableInputShapesUniforms.push(enableShapesUniforms(inputs[i].dims.length)); - inputShapeOrRanks.push(enableInputShapesUniforms[i] ? inputs[i].dims.length : inputs[i].dims); - inputVars[i] = inputVariable(`input${i}`, dataType, inputShapeOrRanks[i]); - inputDependencies.push(enableInputShapesUniforms[i] ? 'rank' : 'dims'); + inputRanks.push(inputs[i].dims.length); + inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]); + inputDependencies.push('rank'); programUniforms.push({type: 'uint32', data: sizeInConcatAxis[i]}); } for (let i = 0; i < inputs.length; ++i) { - if (enableInputShapesUniforms[i]) { - programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); - } - } - - const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); - if (enableOutputShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(outputShape)); + programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); } + programUniforms.push(...createTensorShapeVariables(outputShape)); - const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; - const output = outputVariable('output', dataType, outputShapeOrRank); - + const output = outputVariable('output', dataType, outputShape.length); const indicesAxis = output.indicesGet('indices', adjustedAxis); const sizeInConcatAxisStr = Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(','); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts index 4db7c04ad67be..9e1f58bbfa127 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts @@ -6,8 +6,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; - +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; export interface EinsumAttributes extends AttributeWithCacheKey { readonly equation: string; @@ -181,14 +180,12 @@ class EinsumEquation { const appendMax = (name: string): string => name + '_max'; const createEinsumProgramInfo = - (enableInputShapesUniforms: readonly boolean[], inputShapes: Array, dataType: number, - einsumEquation: EinsumEquation, outputShape: readonly number[]): ProgramInfo => { - const shapeOrRanks = inputShapes.map((dims, index) => enableInputShapesUniforms[index] ? dims.length : dims); - const inputVars = shapeOrRanks.map((shapeOrRank, index) => inputVariable(`input${index}`, dataType, shapeOrRank)); + (inputShapes: Array, dataType: number, einsumEquation: EinsumEquation, + outputShape: readonly number[]): ProgramInfo => { + const ranks = inputShapes.map((dims) => dims.length); + const inputVars = ranks.map((rank, index) => inputVariable(`input${index}`, dataType, rank)); const outputSize = ShapeUtil.size(outputShape); - const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); - const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; - const output = outputVariable('output', dataType, outputShapeOrRank); + const output = outputVariable('output', dataType, outputShape.length); const uniformsSymbols = [...einsumEquation.symbolToInfo.keys()].filter((symbol) => !einsumEquation.rhs.symbolToIndices.has(symbol)); const getShaderSource = (shaderHelper: ShaderHelper) => { @@ -269,10 +266,7 @@ const createEinsumProgramInfo = }; return { name: 'Einsum', - shaderCache: { - hint: einsumEquation.equation, - inputDependencies: enableInputShapesUniforms.map((enableShapeUniform) => enableShapeUniform ? 'rank' : 'dims') - }, + shaderCache: {hint: einsumEquation.equation, inputDependencies: inputShapes.map(() => 'rank')}, getRunData: () => { // The symbols from uniformSymbols array are guaranteed to exist in einsumEquations.symbolToInfo map. The // filter is added to make sure that dimValue is never 0. @@ -281,12 +275,9 @@ const createEinsumProgramInfo = .map((symbol) => ({type: 'uint32', data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0})); programUniformsInit.push({type: 'uint32', data: outputSize}); const programUniforms: ProgramUniform[] = - inputShapes.filter((_, index) => enableInputShapesUniforms[index]) - .map((dims, _) => [...createTensorShapeVariables(dims)]) + inputShapes.map((dims, _) => [...createTensorShapeVariables(dims)]) .reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit); - if (enableOutputShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(outputShape)); - } + programUniforms.push(...createTensorShapeVariables(outputShape)); return ({ outputs: [{dims: outputShape, dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, @@ -299,11 +290,9 @@ const createEinsumProgramInfo = export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => { const einsumEquation = new EinsumEquation(context.inputs, attributes.equation); - const enableInputShapesUniforms = context.inputs.map((input, _) => enableShapesUniforms(input.dims.length)); const outputShape = einsumEquation.outputDims; const inputShapes = context.inputs.map((input, _) => input.dims); - context.compute(createEinsumProgramInfo( - enableInputShapesUniforms, inputShapes, context.inputs[0].dataType, einsumEquation, outputShape)); + context.compute(createEinsumProgramInfo(inputShapes, context.inputs[0].dataType, einsumEquation, outputShape)); }; export const parseEinsumAttributes = (attributes: Record): EinsumAttributes => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts index 035d89755c7d7..dd18bd23a5912 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts @@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 2) { @@ -49,15 +49,9 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => const components = dataType === DataType.bool ? 4 : 1; const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components); - const enableInputShapeUniform = enableShapesUniforms(inputShape.length); - const enableOutputShapeUniform = enableShapesUniforms(outputShape.length); - - const getShaderSource = (shaderHelper: ShaderHelper) => { - const inputShapeOrRank = enableInputShapeUniform ? inputShape.length : inputShape; - const outputShapeOrRank = enableOutputShapeUniform ? outputShape.length : outputShape; - const input = inputVariable('input', dataType, inputShapeOrRank, components); - const output = outputVariable('output', dataType, outputShapeOrRank, components); + const input = inputVariable('input', dataType, inputShape.length, components); + const output = outputVariable('output', dataType, outputShape.length, components); let assignment: string; if (dataType === DataType.bool) { const singleAssignment = (resStr: string, x: number, typeCast = '') => ` @@ -90,16 +84,13 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => ${assignment}`; }; - const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}]; - if (enableInputShapeUniform) { - programUniforms.push(...createTensorShapeVariables(inputShape)); - } - if (enableOutputShapeUniform) { - programUniforms.push(...createTensorShapeVariables(outputShape)); - } + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputShape), + ...createTensorShapeVariables(outputShape) + ]; return { name: 'Expand', - shaderCache: {hint: `${outputShape.length}`, inputDependencies: [enableInputShapeUniform ? 'rank' : 'dims']}, + shaderCache: {hint: `${outputShape.length}`, inputDependencies: ['rank']}, getShaderSource, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index 469249f92ff28..e2a62c6655c72 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -5,9 +5,9 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; export interface GatherAttributes extends AttributeWithCacheKey { axis: number; @@ -33,33 +33,16 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath const components = inputs[0].dataType === DataType.bool ? 4 : 1; const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components); - const enableInputShapesUniforms = enableShapesUniforms(inputs[0].dims.length); - const inputShapeOrRank = enableInputShapesUniforms ? inputs[0].dims.length : inputs[0].dims; - const enableIndicesShapesUniforms = enableShapesUniforms(inputs[1].dims.length); - const indicesShapeOrRank = enableIndicesShapesUniforms ? inputs[1].dims.length : inputs[1].dims; - const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); - const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; - - const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}]; - if (enableInputShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); - } - if (enableIndicesShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); - } - if (enableOutputShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(outputShape)); - } - - const inputDependencies: ProgramInputTensorInfoDependency[] = []; - inputDependencies.push(enableInputShapesUniforms ? 'rank' : 'dims'); - inputDependencies.push(enableIndicesShapesUniforms ? 'rank' : 'dims'); + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}, + ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims), + ...createTensorShapeVariables(outputShape) + ]; const getShaderSource = (shaderHelper: ShaderHelper) => { - const data = inputVariable('data', inputs[0].dataType, inputShapeOrRank, components); - const indices = inputVariable('inputIndices', inputs[1].dataType, indicesShapeOrRank); - const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank, components); + const data = inputVariable('data', inputs[0].dataType, inputs[0].dims.length, components); + const indices = inputVariable('inputIndices', inputs[1].dataType, inputs[1].dims.length); + const output = outputVariable('output', inputs[0].dataType, outputShape.length, components); const calcDataIndices = (x: number|string): string => { const indicesRank = indicesShape.length; @@ -127,7 +110,7 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath }; return { name: 'Gather', - shaderCache: {hint: attributes.cacheKey, inputDependencies}, + shaderCache: {hint: attributes.cacheKey, inputDependencies: ['rank', 'rank']}, getRunData: () => ({ outputs: [ {dims: outputShape, dataType: inputs[0].dataType}, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index c4d43e9f466f5..ab9a9ac8dd1f0 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -6,7 +6,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; export interface TransposeAttributes extends AttributeWithCacheKey { readonly perm: number[]; @@ -39,12 +39,9 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu const inputDataType = inputTensor.dataType; const inputRank = inputTensor.dims.length; const perm = getAdjustedPerm(inputRank, permAttr); - const useShapesUniforms = enableShapesUniforms(inputRank); const outputShape = getOutputShape(inputTensor.dims, perm); - const outShapeOrRank = useShapesUniforms ? outputShape.length : outputShape; - const inShapeOrRank = useShapesUniforms ? inputRank : inputTensor.dims; - const output = outputVariable('output', inputDataType, outShapeOrRank); - const input = inputVariable('a', inputDataType, inShapeOrRank); + const output = outputVariable('output', inputDataType, outputShape.length); + const input = inputVariable('a', inputDataType, inputRank); const getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} @@ -61,21 +58,17 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu }`; return { name: 'Transpose', - shaderCache: {hint: `${permAttr}`, inputDependencies: useShapesUniforms ? ['rank'] : ['dims']}, + shaderCache: {hint: `${permAttr}`, inputDependencies: ['rank']}, getRunData: (inputs) => { const outputSize = ShapeUtil.size(outputShape); return { outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms: useShapesUniforms ? - [ - {type: 'uint32', data: outputSize}, - ...createTensorShapeVariables(inputs[0].dims), - ...createTensorShapeVariables(outputShape), - ] : - [ - {type: 'uint32', data: outputSize}, - ], + programUniforms: [ + {type: 'uint32', data: outputSize}, + ...createTensorShapeVariables(inputs[0].dims), + ...createTensorShapeVariables(outputShape), + ], }; }, getShaderSource,