From e597eaed4afe255b7eda15f57a63a7b399952158 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 19 Nov 2024 04:52:48 +0800 Subject: [PATCH] [js/webgpu] Optimize transpose as reshape when suitable (#22870) BUG #22031 --- js/web/lib/wasm/jsep/webgpu/ops/transpose.ts | 95 ++++++++++++++++---- js/web/test/data/ops/transpose.jsonc | 24 +++++ 2 files changed, 102 insertions(+), 17 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index 1fd99d085e0ed..21225a77b189b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -48,17 +48,61 @@ const squeezeShape = (shape: readonly number[], adjustedPerm: number[]): { newSh return { newShape, newPerm }; }; +const isTransposeReshape = (perm: number[], shape: readonly number[]) => { + // As long as the dims with values > 1 stay in the same order, it's a reshape. + // Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1). + let lastPermutedAxis = 0; + for (let i = 0; i < perm.length; ++i) { + if (shape[perm[i]] === 1) { + continue; + } + if (perm[i] < lastPermutedAxis) { + return false; + } + lastPermutedAxis = perm[i]; + } + return true; +}; + export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: number[]): ProgramInfo => { const inputDataType = inputTensor.dataType; const inputRank = inputTensor.dims.length; const perm = getAdjustedPerm(inputRank, permAttr); const outputShape = getOutputShape(inputTensor.dims, perm); + let newInputShape = inputTensor.dims; + let newOutputShape = outputShape; + const transposeAsReshape = isTransposeReshape(perm, inputTensor.dims); + let getShaderSource; + if (transposeAsReshape) { + getShaderSource = (shaderHelper: ShaderHelper) => { + const input = inputVariable('input', inputDataType, newInputShape, 4); + const output = outputVariable('output', inputDataType, newOutputShape, 4); + return ` + ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + output[global_idx] = input[global_idx]; + }`; + }; + + return { + name: 'TransposeCopy', + shaderCache: { inputDependencies: ['type'] }, + getRunData: () => { + const outputSize = ShapeUtil.size(outputShape); + return { + outputs: [{ dims: outputShape, dataType: inputTensor.dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* components */) }, + programUniforms: [{ type: DataType.uint32, data: Math.ceil(outputSize / 4) }], + }; + }, + getShaderSource, + }; + } const { newShape, newPerm } = squeezeShape(inputTensor.dims, perm); const channelsLast = ShapeUtil.areEqual(newPerm, [2, 3, 1]); const channelsFirst = ShapeUtil.areEqual(newPerm, [3, 1, 2]); - const useShared = (newShape.length === 2 && newPerm[0] > newPerm[1]) || channelsLast || channelsFirst; - let newInputShape = useShared ? newShape : inputTensor.dims; - let newOutputShape = outputShape; + const useShared = newShape.length === 2 || channelsLast || channelsFirst; if (useShared) { newInputShape = channelsLast ? [newShape[0], newShape[1] * newShape[2]] @@ -66,13 +110,11 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu ? [newShape[0] * newShape[1], newShape[2]] : newShape; newOutputShape = [newInputShape[1], newInputShape[0]]; - } - const input = inputVariable('a', inputDataType, newInputShape.length); - const output = outputVariable('output', inputDataType, newOutputShape.length); - const tileSize = 16; - let getShaderSource; - if (useShared) { - getShaderSource = (shaderHelper: ShaderHelper) => ` + const tileSize = 16; + getShaderSource = (shaderHelper: ShaderHelper) => { + const input = inputVariable('a', inputDataType, newInputShape.length); + const output = outputVariable('output', inputDataType, newOutputShape.length); + return ` ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} var tile : array, ${tileSize}>; ${shaderHelper.mainStart([tileSize, tileSize, 1])} @@ -92,8 +134,29 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu ${output.setByIndices(`${output.type.indices}(output_row, output_col)`, 'tile[local_id.x][local_id.y]')} } }`; - } else { - getShaderSource = (shaderHelper: ShaderHelper) => ` + }; + return { + name: 'TransposeShared', + shaderCache: { inputDependencies: ['type'] }, + getRunData: () => { + const outputSize = ShapeUtil.size(outputShape); + return { + outputs: [{ dims: outputShape, dataType: inputTensor.dataType }], + dispatchGroup: { x: Math.ceil(newOutputShape[1] / tileSize), y: Math.ceil(newOutputShape[0] / tileSize) }, + programUniforms: [ + { type: DataType.uint32, data: outputSize }, + ...createTensorShapeVariables(newInputShape, newOutputShape), + ], + }; + }, + getShaderSource, + }; + } + + getShaderSource = (shaderHelper: ShaderHelper) => { + const input = inputVariable('a', inputDataType, newInputShape.length); + const output = outputVariable('output', inputDataType, newOutputShape.length); + return ` ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} ${permFunctionBody(perm, inputRank, input, output)} @@ -106,17 +169,15 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu ${output.setByOffset('global_idx', input.getByIndices('aIndices'))} }`; - } + }; return { - name: useShared ? 'TransposeShared' : 'Transpose', + name: 'Transpose', shaderCache: { hint: `${permAttr}`, inputDependencies: ['rank'] }, getRunData: () => { const outputSize = ShapeUtil.size(outputShape); return { outputs: [{ dims: outputShape, dataType: inputTensor.dataType }], - dispatchGroup: useShared - ? { x: Math.ceil(newOutputShape[1] / tileSize), y: Math.ceil(newOutputShape[0] / tileSize) } - : { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, programUniforms: [ { type: DataType.uint32, data: outputSize }, ...createTensorShapeVariables(newInputShape, newOutputShape), diff --git a/js/web/test/data/ops/transpose.jsonc b/js/web/test/data/ops/transpose.jsonc index a7265d6444118..d431ceb1712a5 100644 --- a/js/web/test/data/ops/transpose.jsonc +++ b/js/web/test/data/ops/transpose.jsonc @@ -263,6 +263,30 @@ } ] }, + { + "name": "Transpose as reshape - perms:[1, 0, 2, 4, 3]", + "operator": "Transpose", + "attributes": [{ "name": "perm", "data": [1, 0, 2, 4, 3], "type": "ints" }], + "cases": [ + { + "name": "T[3, 1, 2, 1, 4]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], + "dims": [3, 1, 2, 1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], + "dims": [1, 3, 2, 4, 1], + "type": "float32" + } + ] + } + ] + }, { "name": "Transpose - perms:[1, 0]", "operator": "Transpose",