Skip to content

Commit

Permalink
[js/webgpu] fix a bug in transpose shader
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Dec 3, 2024
1 parent 8c52fa3 commit 5fd8ee5
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/transpose.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, ou
let reverseFunc = `fn perm(i: ${output.type.indices}) -> ${input.type.indices} {
var a: ${input.type.indices};`;
for (let i = 0; i < rank; ++i) {
reverseFunc += input.indicesSet('a', perm[i], `i[${i}]`);
// input indices and output indices should always be larger or equal to 2,
// so indexer is always valid to be used on `a` and `i`.
reverseFunc += `a[${perm[i]}]=i[${i}];`;
}
return (reverseFunc += 'return a;}');
};
Expand Down Expand Up @@ -71,7 +73,7 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
const outputShape = getOutputShape(inputTensor.dims, perm);
let newInputShape = inputTensor.dims;
let newOutputShape = outputShape;
const transposeAsReshape = isTransposeReshape(perm, inputTensor.dims);
const transposeAsReshape = inputRank < 2 || isTransposeReshape(perm, inputTensor.dims);
let getShaderSource;
if (transposeAsReshape) {
getShaderSource = (shaderHelper: ShaderHelper) => {
Expand Down

0 comments on commit 5fd8ee5

Please sign in to comment.