Skip to content

Commit

Permalink
[js/webgpu] Fix the transpose error when dims > 4D
Browse files Browse the repository at this point in the history
  • Loading branch information
qjia7 committed Oct 19, 2023
1 parent 35ecce4 commit 2bcc932
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 22 deletions.
5 changes: 5 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -803,3 +803,8 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly
}
return dims;
};

// TODO: remove this limitation once >4D dims are supported by uniform.
export const enableShapesUniforms = (rank: number): boolean => {
return rank <= 4;
}

Check notice

Code scanning / CodeQL

Semicolon insertion Note

Avoid automated semicolon insertion (93% of all statements in
the enclosing script
have an explicit semicolon).
51 changes: 29 additions & 22 deletions js/web/lib/wasm/jsep/webgpu/ops/transpose.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';

import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper, useShapesUniforms} from './common';

Check notice

Code scanning / CodeQL

Unused variable, import, function or class Note

Unused import useShapesUniforms.

export interface TransposeAttributes extends AttributeWithCacheKey {
readonly perm: number[];
Expand Down Expand Up @@ -35,13 +35,18 @@ const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, ou
return reverseFunc.join('\n');
};

export const createTransposeProgramInfo =
(inputDataType: number, inputRank: number, permAttr: number[]): ProgramInfo => {
const perm = getAdjustedPerm(inputRank, permAttr);
const output = outputVariable('output', inputDataType, (permAttr && permAttr.length) || inputRank);
const input = inputVariable('a', inputDataType, inputRank);
export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: number[]): ProgramInfo => {
const inputDataType = inputTensor.dataType;
const inputRank = inputTensor.dims.length;
const perm = getAdjustedPerm(inputRank, permAttr);
const useShapesUniforms = enableShapesUniforms(inputRank);
const outputShape = getOutputShape(inputTensor.dims, perm);
const outShapeOrRank = useShapesUniforms ? outputShape.length : outputShape;
const inShapeOrRank = useShapesUniforms ? inputRank : inputTensor.dims;
const output = outputVariable('output', inputDataType, outShapeOrRank);
const input = inputVariable('a', inputDataType, inShapeOrRank);

const getShaderSource = (shaderHelper: ShaderHelper) => `
const getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
${permFunctionBody(perm, inputRank, input, output)}
Expand All @@ -54,30 +59,32 @@ export const createTransposeProgramInfo =
${output.setByOffset('global_idx', input.getByIndices('aIndices'))}
}`;
return {
name: 'Transpose',
shaderCache: {hint: `${permAttr}`, inputDependencies: useShapesUniforms ? ['rank'] : ['dims']},
getRunData: (inputs) => {
const outputSize = ShapeUtil.size(outputShape);
return {
name: 'Transpose',
shaderCache: {hint: `${permAttr}`, inputDependencies: ['rank']},
getRunData: (inputs) => {
const outputShape = getOutputShape(inputs[0].dims, perm);
const outputSize = ShapeUtil.size(outputShape);
return {
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms: [
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms: useShapesUniforms ?
[
{type: 'uint32', data: outputSize},
...createTensorShapeVariables(inputs[0].dims),
...createTensorShapeVariables(outputShape),
] :
[
{type: 'uint32', data: outputSize},
],
};
},
getShaderSource,
};
};
},
getShaderSource,
};
};

export const transpose = (context: ComputeContext, attributes: TransposeAttributes): void => {
validateInputs(context.inputs);
context.compute(
createTransposeProgramInfo(context.inputs[0].dataType, context.inputs[0].dims.length, attributes.perm));
context.compute(createTransposeProgramInfo(context.inputs[0], attributes.perm));
};

export const parseTransposeAttributes = (attributes: Record<string, unknown>): TransposeAttributes =>
Expand Down
24 changes: 24 additions & 0 deletions js/web/test/data/ops/transpose.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -166,5 +166,29 @@
]
}
]
},
{
"name": "Transpose 5D - perms:[4, 3, 1, 0, 2]",
"operator": "Transpose",
"attributes": [{ "name": "perm", "data": [4, 3, 1, 0, 2], "type": "ints" }],
"cases": [
{
"name": "T[3, 1, 2, 1, 4]",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
"dims": [3, 1, 2, 1, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23, 4, 8, 12, 16, 20, 24],
"dims": [4, 1, 1, 3, 2],
"type": "float32"
}
]
}
]
}
]

0 comments on commit 2bcc932

Please sign in to comment.