Skip to content

Commit

Permalink
[js/webgpu] Fix the transpose error when dims > 4D (microsoft#18027)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
Currently, the uniform support has bugs when dims rank is larger than 4.
See microsoft#17860 item 1.
So this PR only enables shapes uniforms when shape rank is <= 4 for
transpose. Otherwise, below compilation errors are thrown:
```
1 error(s) generated while compiling the shader:
:3:50 error: uniform storage requires that array elements are aligned to 16 bytes, but array element of type 'u32' has a stride of 4 bytes. Consider using a vector or struct as the element type instead.
      struct Uniforms { output_size:u32, a_shape:array<u32, 5>, a_strides:array<u32, 5>, output_shape:array<u32, 5>, output_strides:array<u32, 5> };
                                                 ^^^^^^^^^^^^^

:3:7 note: see layout of struct:
/*            align(4) size(84) */ struct Uniforms {
/* offset( 0) align(4) size( 4) */   output_size : u32;
/* offset( 4) align(4) size(20) */   a_shape : array<u32, 5>;
/* offset(24) align(4) size(20) */   a_strides : array<u32, 5>;
/* offset(44) align(4) size(20) */   output_shape : array<u32, 5>;
/* offset(64) align(4) size(20) */   output_strides : array<u32, 5>;
/*                              */ };
      struct Uniforms { output_size:u32, a_shape:array<u32, 5>, a_strides:array<u32, 5>, output_shape:array<u32, 5>, output_strides:array<u32, 5> };
      ^^^^^^

:4:42 note: 'Uniforms' used in address space 'uniform' here
      @group(0) @binding(2) var<uniform> uniforms: Uniforms;
                                         ^^^^^^^^
```
  • Loading branch information
qjia7 authored Oct 23, 2023
1 parent fc0e933 commit f7244c7
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 25 deletions.
3 changes: 3 additions & 0 deletions web/lib/wasm/jsep/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -803,3 +803,6 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly
}
return dims;
};

// TODO: remove this limitation once >4D dims are supported by uniform.
export const enableShapesUniforms = (rank: number): boolean => rank <= 4;
2 changes: 1 addition & 1 deletion web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ const convTranspose2d =
// STEP.1: transpose weight
const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ??
context.compute(
createTransposeProgramInfo(inputs[1].dataType, inputs[1].dims.length, weightTransposePerm),
createTransposeProgramInfo(inputs[1], weightTransposePerm),
{inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0];
if (attributes.wIsConst && !context.kernelCustomData.wT) {
context.kernelCustomData.wT = transposedWeight;
Expand Down
4 changes: 2 additions & 2 deletions web/lib/wasm/jsep/webgpu/ops/conv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
if (isChannelsLast) {
const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ??
context.compute(
createTransposeProgramInfo(inputs[1].dataType, inputs[1].dims.length, weightTransposeAttribute),
createTransposeProgramInfo(inputs[1], weightTransposeAttribute),
{inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0];
if (attributes.wIsConst && !context.kernelCustomData.wT) {
context.kernelCustomData.wT = transposedWeight;
Expand Down Expand Up @@ -208,7 +208,7 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
// STEP.1: transpose weight
const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ??
context.compute(
createTransposeProgramInfo(inputs[1].dataType, inputs[1].dims.length, weightTransposeAttribute),
createTransposeProgramInfo(inputs[1], weightTransposeAttribute),
{inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0];
if (attributes.wIsConst && !context.kernelCustomData.wT) {
context.kernelCustomData.wT = transposedWeight;
Expand Down
51 changes: 29 additions & 22 deletions web/lib/wasm/jsep/webgpu/ops/transpose.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';

import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';

export interface TransposeAttributes extends AttributeWithCacheKey {
readonly perm: number[];
Expand Down Expand Up @@ -35,13 +35,18 @@ const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, ou
return reverseFunc.join('\n');
};

export const createTransposeProgramInfo =
(inputDataType: number, inputRank: number, permAttr: number[]): ProgramInfo => {
const perm = getAdjustedPerm(inputRank, permAttr);
const output = outputVariable('output', inputDataType, (permAttr && permAttr.length) || inputRank);
const input = inputVariable('a', inputDataType, inputRank);
export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: number[]): ProgramInfo => {
const inputDataType = inputTensor.dataType;
const inputRank = inputTensor.dims.length;
const perm = getAdjustedPerm(inputRank, permAttr);
const useShapesUniforms = enableShapesUniforms(inputRank);
const outputShape = getOutputShape(inputTensor.dims, perm);
const outShapeOrRank = useShapesUniforms ? outputShape.length : outputShape;
const inShapeOrRank = useShapesUniforms ? inputRank : inputTensor.dims;
const output = outputVariable('output', inputDataType, outShapeOrRank);
const input = inputVariable('a', inputDataType, inShapeOrRank);

const getShaderSource = (shaderHelper: ShaderHelper) => `
const getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
${permFunctionBody(perm, inputRank, input, output)}
Expand All @@ -54,30 +59,32 @@ export const createTransposeProgramInfo =
${output.setByOffset('global_idx', input.getByIndices('aIndices'))}
}`;
return {
name: 'Transpose',
shaderCache: {hint: `${permAttr}`, inputDependencies: useShapesUniforms ? ['rank'] : ['dims']},
getRunData: (inputs) => {
const outputSize = ShapeUtil.size(outputShape);
return {
name: 'Transpose',
shaderCache: {hint: `${permAttr}`, inputDependencies: ['rank']},
getRunData: (inputs) => {
const outputShape = getOutputShape(inputs[0].dims, perm);
const outputSize = ShapeUtil.size(outputShape);
return {
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms: [
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms: useShapesUniforms ?
[
{type: 'uint32', data: outputSize},
...createTensorShapeVariables(inputs[0].dims),
...createTensorShapeVariables(outputShape),
] :
[
{type: 'uint32', data: outputSize},
],
};
},
getShaderSource,
};
};
},
getShaderSource,
};
};

export const transpose = (context: ComputeContext, attributes: TransposeAttributes): void => {
validateInputs(context.inputs);
context.compute(
createTransposeProgramInfo(context.inputs[0].dataType, context.inputs[0].dims.length, attributes.perm));
context.compute(createTransposeProgramInfo(context.inputs[0], attributes.perm));
};

export const parseTransposeAttributes = (attributes: Record<string, unknown>): TransposeAttributes =>
Expand Down
24 changes: 24 additions & 0 deletions web/test/data/ops/transpose.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -166,5 +166,29 @@
]
}
]
},
{
"name": "Transpose 5D - perms:[4, 3, 1, 0, 2]",
"operator": "Transpose",
"attributes": [{ "name": "perm", "data": [4, 3, 1, 0, 2], "type": "ints" }],
"cases": [
{
"name": "T[3, 1, 2, 1, 4]",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
"dims": [3, 1, 2, 1, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23, 4, 8, 12, 16, 20, 24],
"dims": [4, 1, 1, 3, 2],
"type": "float32"
}
]
}
]
}
]

0 comments on commit f7244c7

Please sign in to comment.