diff --git a/web/lib/wasm/jsep/webgpu/ops/common.ts b/web/lib/wasm/jsep/webgpu/ops/common.ts index 0a64d1ad1792a..1d3fc78fe368a 100644 --- a/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -803,3 +803,6 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly } return dims; }; + +// TODO: remove this limitation once >4D dims are supported by uniform. +export const enableShapesUniforms = (rank: number): boolean => rank <= 4; diff --git a/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index d241b8b92a669..e880afe09a5d8 100644 --- a/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -232,7 +232,7 @@ const convTranspose2d = // STEP.1: transpose weight const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? context.compute( - createTransposeProgramInfo(inputs[1].dataType, inputs[1].dims.length, weightTransposePerm), + createTransposeProgramInfo(inputs[1], weightTransposePerm), {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; if (attributes.wIsConst && !context.kernelCustomData.wT) { context.kernelCustomData.wT = transposedWeight; diff --git a/web/lib/wasm/jsep/webgpu/ops/conv.ts b/web/lib/wasm/jsep/webgpu/ops/conv.ts index b323a36cee5c8..c7ea0cffe51c3 100644 --- a/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -168,7 +168,7 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut if (isChannelsLast) { const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? context.compute( - createTransposeProgramInfo(inputs[1].dataType, inputs[1].dims.length, weightTransposeAttribute), + createTransposeProgramInfo(inputs[1], weightTransposeAttribute), {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; if (attributes.wIsConst && !context.kernelCustomData.wT) { context.kernelCustomData.wT = transposedWeight; @@ -208,7 +208,7 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut // STEP.1: transpose weight const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? context.compute( - createTransposeProgramInfo(inputs[1].dataType, inputs[1].dims.length, weightTransposeAttribute), + createTransposeProgramInfo(inputs[1], weightTransposeAttribute), {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; if (attributes.wIsConst && !context.kernelCustomData.wT) { context.kernelCustomData.wT = transposedWeight; diff --git a/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/web/lib/wasm/jsep/webgpu/ops/transpose.ts index fe556a7fd8552..c4d43e9f466f5 100644 --- a/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -6,7 +6,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; export interface TransposeAttributes extends AttributeWithCacheKey { readonly perm: number[]; @@ -35,13 +35,18 @@ const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, ou return reverseFunc.join('\n'); }; -export const createTransposeProgramInfo = - (inputDataType: number, inputRank: number, permAttr: number[]): ProgramInfo => { - const perm = getAdjustedPerm(inputRank, permAttr); - const output = outputVariable('output', inputDataType, (permAttr && permAttr.length) || inputRank); - const input = inputVariable('a', inputDataType, inputRank); +export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: number[]): ProgramInfo => { + const inputDataType = inputTensor.dataType; + const inputRank = inputTensor.dims.length; + const perm = getAdjustedPerm(inputRank, permAttr); + const useShapesUniforms = enableShapesUniforms(inputRank); + const outputShape = getOutputShape(inputTensor.dims, perm); + const outShapeOrRank = useShapesUniforms ? outputShape.length : outputShape; + const inShapeOrRank = useShapesUniforms ? inputRank : inputTensor.dims; + const output = outputVariable('output', inputDataType, outShapeOrRank); + const input = inputVariable('a', inputDataType, inShapeOrRank); - const getShaderSource = (shaderHelper: ShaderHelper) => ` + const getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} ${permFunctionBody(perm, inputRank, input, output)} @@ -54,30 +59,32 @@ export const createTransposeProgramInfo = ${output.setByOffset('global_idx', input.getByIndices('aIndices'))} }`; + return { + name: 'Transpose', + shaderCache: {hint: `${permAttr}`, inputDependencies: useShapesUniforms ? ['rank'] : ['dims']}, + getRunData: (inputs) => { + const outputSize = ShapeUtil.size(outputShape); return { - name: 'Transpose', - shaderCache: {hint: `${permAttr}`, inputDependencies: ['rank']}, - getRunData: (inputs) => { - const outputShape = getOutputShape(inputs[0].dims, perm); - const outputSize = ShapeUtil.size(outputShape); - return { - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms: [ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms: useShapesUniforms ? + [ {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape), + ] : + [ + {type: 'uint32', data: outputSize}, ], - }; - }, - getShaderSource, }; - }; + }, + getShaderSource, + }; +}; export const transpose = (context: ComputeContext, attributes: TransposeAttributes): void => { validateInputs(context.inputs); - context.compute( - createTransposeProgramInfo(context.inputs[0].dataType, context.inputs[0].dims.length, attributes.perm)); + context.compute(createTransposeProgramInfo(context.inputs[0], attributes.perm)); }; export const parseTransposeAttributes = (attributes: Record): TransposeAttributes => diff --git a/web/test/data/ops/transpose.jsonc b/web/test/data/ops/transpose.jsonc index 285d14018e74d..e1edfa7e41513 100644 --- a/web/test/data/ops/transpose.jsonc +++ b/web/test/data/ops/transpose.jsonc @@ -166,5 +166,29 @@ ] } ] + }, + { + "name": "Transpose 5D - perms:[4, 3, 1, 0, 2]", + "operator": "Transpose", + "attributes": [{ "name": "perm", "data": [4, 3, 1, 0, 2], "type": "ints" }], + "cases": [ + { + "name": "T[3, 1, 2, 1, 4]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], + "dims": [3, 1, 2, 1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23, 4, 8, 12, 16, 20, 24], + "dims": [4, 1, 1, 3, 2], + "type": "float32" + } + ] + } + ] } ]