From 5fd8ee5b17fbe1ba0827fee7038741013cdd665b Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 3 Dec 2024 13:52:25 -0800 Subject: [PATCH] [js/webgpu] fix a bug in transpose shader --- js/web/lib/wasm/jsep/webgpu/ops/transpose.ts | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index 21225a77b189b..5059645211aea 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -29,7 +29,9 @@ const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, ou let reverseFunc = `fn perm(i: ${output.type.indices}) -> ${input.type.indices} { var a: ${input.type.indices};`; for (let i = 0; i < rank; ++i) { - reverseFunc += input.indicesSet('a', perm[i], `i[${i}]`); + // input indices and output indices should always be larger or equal to 2, + // so indexer is always valid to be used on `a` and `i`. + reverseFunc += `a[${perm[i]}]=i[${i}];`; } return (reverseFunc += 'return a;}'); }; @@ -71,7 +73,7 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu const outputShape = getOutputShape(inputTensor.dims, perm); let newInputShape = inputTensor.dims; let newOutputShape = outputShape; - const transposeAsReshape = isTransposeReshape(perm, inputTensor.dims); + const transposeAsReshape = inputRank < 2 || isTransposeReshape(perm, inputTensor.dims); let getShaderSource; if (transposeAsReshape) { getShaderSource = (shaderHelper: ShaderHelper) => {