[js/webgpu] Optimize conv1d by conv2d (#19388)
### Description

Route conv1d through the conv2d path so that it benefits from conv2d's existing optimizations.

With this change, the whisper-tiny-encoder model improves from 532.28 ms to 158.66 ms: Conv now dispatches to Conv2DMatMul (8 ms) instead of GroupedConv (382 ms).
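
Conceptually, the lowering just inserts a dummy height dimension so conv1d can reuse the conv2d kernels, then squeezes the result back to 3-D. A minimal shape-bookkeeping sketch (the helper names are illustrative, not the runtime's API):

```typescript
// Hypothetical helpers (illustrative names) showing only the shape bookkeeping:
// a conv1d input [N, C, W] gains a dummy H = 1 dimension so the existing
// conv2d kernels can run on it.
const reshape1dTo2d = (dims: readonly number[]): number[] => [dims[0], dims[1], 1, dims[2]]; // [N, C, W] -> [N, C, 1, W]

// The 4-D conv2d result is then squeezed back to the 3-D conv1d shape.
const squeeze2dTo1d = (dims: readonly number[]): number[] => [dims[0], dims[1], dims[3]]; // [N, C, 1, outW] -> [N, C, outW]

// e.g. a [1, 1, 3] input with a [1, 1, 2] kernel runs as conv2d on
// [1, 1, 1, 3] x [1, 1, 1, 2]; the [1, 1, 1, 2] output becomes [1, 1, 2].
console.log(reshape1dTo2d([1, 1, 3]), squeeze2dTo1d([1, 1, 1, 2]));
```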

Old profiling result:
Kernel | Time (ms) | Percentage (%)
-- | -- | --
Conv\|GroupedConv | 382.99 | 71.95
MatMul | 126.16 | 23.70
Softmax | 7.01 | 1.32
Transpose | 4.59 | 0.86
Add | 4.39 | 0.82
Mul | 2.36 | 0.44
Div | 1.44 | 0.27
ReduceMean\|ReduceMeanShared | 1.25 | 0.23
Erf | 0.85 | 0.16
Sub | 0.72 | 0.14
Pow | 0.46 | 0.09
Sqrt | 0.07 | 0.01
Sum | 532.28 |  

New profiling result with this PR:

Kernel | Time (ms) | Percentage (%)
-- | -- | --
MatMul | 127.07 | 80.09
Conv\|Conv2DMatMul | 8.00 | 5.04
Softmax | 6.95 | 4.38
Transpose | 4.65 | 2.93
Add | 4.26 | 2.68
Mul | 2.56 | 1.61
Div | 1.51 | 0.95
ReduceMean\|ReduceMeanShared | 1.31 | 0.83
Erf | 0.85 | 0.54
Sub | 0.79 | 0.50
Pow | 0.46 | 0.29
Conv\|Transpose | 0.26 | 0.17
Sqrt | 0.00 | 0.00
Sum | 158.66 |  

---------

Co-authored-by: Yulong Wang <[email protected]>
qjia7 and fs-eire authored Aug 23, 2024
1 parent 0368dd4 commit 27a6890
Showing 8 changed files with 141 additions and 43 deletions.
12 changes: 8 additions & 4 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
@@ -182,6 +182,7 @@ export const createConv2DMatMulProgramInfo = (
dimInner: number,
hasBias: boolean,
sequentialAccessByThreads: boolean,
+ squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
): ProgramInfo => {
const isChannelsLast = attributes.format === 'NHWC';
const inChannels = isChannelsLast ? inputs[0].dims[3] : inputs[0].dims[1];
@@ -309,13 +310,16 @@ export const createConv2DMatMulProgramInfo = (
return {
name: 'Conv2DMatMul',
shaderCache: {
- hint: `${attributes.cacheKey};${innerElementSize};${isVec4};${fitAOuter};${fitBOuter};${fitInner};${
-   tileAOuter
- };${tileBOuter};${tileInner}`,
+ hint: `${attributes.cacheKey};${innerElementSize};${isVec4};${fitAOuter};${fitBOuter};${fitInner};${tileAOuter};${tileBOuter};${tileInner}`,
inputDependencies,
},
getRunData: () => ({
- outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
+ outputs: [
+   {
+     dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape,
+     dataType: inputs[0].dataType,
+   },
+ ],
dispatchGroup: { x: dispatch[0], y: dispatch[1], z: dispatch[2] },
programUniforms,
}),
22 changes: 10 additions & 12 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
@@ -110,13 +110,9 @@ export const makeMatMulPackedVec4Source = (
workPerThread[0] === 4
)
) {
- throw new Error(`If transposeA ${transposeA} is true, innerElementSize ${
-   innerElementSize
- } and workPerThread[1] ${workPerThread[1]} must be 4.
+ throw new Error(`If transposeA ${transposeA} is true, innerElementSize ${innerElementSize} and workPerThread[1] ${workPerThread[1]} must be 4.
Otherwise, innerElementSize ${innerElementSize} must be 3 or 4.
- tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${workgroupSize[0]}. tileInner ${
-   tileInner
- } must be divisible by workgroupSize[1] ${workgroupSize[1]}. colPerThread ${workPerThread[0]} must be 4.`);
+ tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${workgroupSize[0]}. tileInner ${tileInner} must be divisible by workgroupSize[1] ${workgroupSize[1]}. colPerThread ${workPerThread[0]} must be 4.`);
}
return `
var<workgroup> mm_Asub: array<array<vec${innerElementSize}<${type}>, ${tileAWidth / innerElementSize}>, ${tileAHight}>;
@@ -227,11 +223,7 @@ export const makeMatMulPackedSource = (
!(tileAHight % workgroupSize[1] === 0 && tileAWidth % workgroupSize[0] === 0 && tileInner % workgroupSize[1] === 0)
) {
throw new Error(
- `tileAHight ${tileAHight} must be divisible by workgroupSize[1]${
-   workgroupSize[1]
- }, tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${
-   workgroupSize[0]
- }, tileInner ${tileInner} must be divisible by workgroupSize[1]${workgroupSize[1]}`,
+ `tileAHight ${tileAHight} must be divisible by workgroupSize[1]${workgroupSize[1]}, tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${workgroupSize[0]}, tileInner ${tileInner} must be divisible by workgroupSize[1]${workgroupSize[1]}`,
);
}
const rowPerThreadA = tileAHight / workgroupSize[1];
@@ -470,6 +462,7 @@ export const createMatmulProgramInfo = (
outputShape: readonly number[],
reshapedOutputShape?: readonly number[],
isChannelsLast = false /* only used for conv2dByMatMul*/,
+ squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
): ProgramInfo => {
const aShape = inputs[0].dims;
const bShape = inputs[1].dims;
@@ -562,7 +555,12 @@ export const createMatmulProgramInfo = (
inputDependencies,
},
getRunData: () => ({
- outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
+ outputs: [
+   {
+     dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape,
+     dataType: inputs[0].dataType,
+   },
+ ],
dispatchGroup: { x: dispatch[0], y: dispatch[1], z: dispatch[2] },
programUniforms,
}),
8 changes: 7 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts
@@ -145,6 +145,7 @@ export const createGroupedConvVectorizeProgramInfo = (
inputs: readonly TensorView[],
attributes: ConvAttributes,
outputShape: readonly number[],
+ squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
): ProgramInfo => {
const hasBias = inputs.length > 2;
const components = getMaxComponents(outputShape[3]);
@@ -234,7 +235,12 @@ export const createGroupedConvVectorizeProgramInfo = (
inputDependencies: hasBias ? ['rank', 'rank', 'type'] : ['rank', 'rank'],
},
getRunData: () => ({
- outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
+ outputs: [
+   {
+     dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape,
+     dataType: inputs[0].dataType,
+   },
+ ],
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
programUniforms,
}),
54 changes: 35 additions & 19 deletions js/web/lib/wasm/jsep/webgpu/ops/conv.ts
@@ -152,9 +152,12 @@ export const parseConvAttributes = (attributes: Record<string, unknown>): ConvAt
};
};

- const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvAttributes): void => {
-   const adjustedAttributes = getAdjustedConvAttributes(attributes, inputs);
-
+ const conv2d = (
+   context: ComputeContext,
+   inputs: readonly TensorView[],
+   attributes: ConvAttributes,
+   squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
+ ): void => {
// check attributes

// const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */
@@ -177,7 +180,7 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
inputs[0].dims,
inputs[1].dims,
attributes.dilations,
- adjustedAttributes.pads,
+ attributes.pads,
attributes.strides,
isChannelsLast,
);
@@ -194,11 +197,12 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
if (inputs.length === 3) {
convInputs.push(inputs[2]);
}
- context.compute(createGroupedConvVectorizeProgramInfo(convInputs, adjustedAttributes, outputShape), {
-   inputs: convInputs,
- });
+ context.compute(
+   createGroupedConvVectorizeProgramInfo(convInputs, attributes, outputShape, squeezeOutputShapeFunction),
+   { inputs: convInputs },
+ );
} else {
- context.compute(createGroupedConvProgramInfo(inputs, adjustedAttributes));
+ context.compute(createGroupedConvProgramInfo(inputs, attributes, squeezeOutputShapeFunction));
}
return;
}
@@ -214,7 +218,7 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
inputs[0].dims,
inputs[1].dims,
attributes.dilations,
- adjustedAttributes.pads,
+ attributes.pads,
attributes.strides,
isChannelsLast,
);
@@ -280,12 +284,26 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
// Tune the threshold.
if (N < 8 && K < 8) {
context.compute(
- createNaiveMatmulProgramInfo(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast),
+ createNaiveMatmulProgramInfo(
+   matmulInputs,
+   attributes,
+   outputShape,
+   matmulOutputShape,
+   isChannelsLast,
+   squeezeOutputShapeFunction,
+ ),
{ inputs: matmulInputs },
);
} else {
context.compute(
- createMatmulProgramInfo(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast),
+ createMatmulProgramInfo(
+   matmulInputs,
+   attributes,
+   outputShape,
+   matmulOutputShape,
+   isChannelsLast,
+   squeezeOutputShapeFunction,
+ ),
{ inputs: matmulInputs },
);
}
@@ -320,13 +338,14 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
context.compute(
createConv2DMatMulProgramInfo(
convInputs,
- adjustedAttributes,
+ attributes,
outputShape,
dimAOuter,
dimBOuter,
dimInner,
hasBias,
sequentialAccessByThreads,
+ squeezeOutputShapeFunction,
),
{ inputs: convInputs },
);
@@ -357,12 +376,8 @@ const conv1d = (context: ComputeContext, attributes: ConvAttributes): void => {
{ ...attributes, pads, strides, dilations, kernelShape },
inputs,
);
- context.compute(
-   createGroupedConvProgramInfo(inputs, adjustedAttributes, (outputShape) =>
-     isChannelLast
-       ? [outputShape[0], outputShape[2], outputShape[3]]
-       : [outputShape[0], outputShape[1], outputShape[3]],
-   ),
+ conv2d(context, inputs, adjustedAttributes, (outputShape) =>
+   isChannelLast ? [outputShape[0], outputShape[2], outputShape[3]] : [outputShape[0], outputShape[1], outputShape[3]],
);
};

@@ -398,6 +413,7 @@ export const conv = (context: ComputeContext, attributes: ConvAttributes): void
} else if (context.inputs[0].dims.length === 5) {
conv3d(context, context.inputs, attributes);
} else {
- conv2d(context, context.inputs, attributes);
+ const adjustedAttributes = getAdjustedConvAttributes(attributes, context.inputs);
+ conv2d(context, context.inputs, adjustedAttributes);
}
};
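
The `squeezeOutputShapeFunction` threaded through conv.ts above maps the 4-D conv2d output shape back to the 3-D conv1d shape. A standalone sketch of that mapping (illustrative only, mirroring the callback `conv1d` passes to `conv2d`):

```typescript
// Maps a 4-D conv2d output shape back to the 3-D conv1d shape; mirrors the
// callback that conv1d passes to conv2d in the diff above.
const squeezeConv1dOutput = (outputShape: readonly number[], isChannelLast: boolean): number[] =>
  isChannelLast
    ? [outputShape[0], outputShape[2], outputShape[3]] // NHWC: [N, 1, W, C] -> [N, W, C]
    : [outputShape[0], outputShape[1], outputShape[3]]; // NCHW: [N, C, 1, W] -> [N, C, W]

console.log(squeezeConv1dOutput([1, 4, 1, 8], false)); // [1, 4, 8]
```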
12 changes: 8 additions & 4 deletions js/web/lib/wasm/jsep/webgpu/ops/matmul.ts
@@ -32,6 +32,7 @@ export const createNaiveMatmulProgramInfo = (
outputShape: readonly number[],
reshapedOutputShape?: readonly number[],
isChannelsLast = false /* only used for conv2dByMatMul*/,
+ squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
): ProgramInfo => {
const aShape = inputs[0].dims;
const bShape = inputs[1].dims;
@@ -120,9 +121,7 @@ export const createNaiveMatmulProgramInfo = (

for (let j = 0; j < aComponents; j++) {
calcStr += `
- values[${i}] = fma(${b.type.value}(a_data${aComponents === 1 ? '' : `[${j}]`}), b_data${j}, values[${
-   i
- }]);\n`;
+ values[${i}] = fma(${b.type.value}(a_data${aComponents === 1 ? '' : `[${j}]`}), b_data${j}, values[${i}]);\n`;
}
}
return calcStr;
@@ -168,7 +167,12 @@ export const createNaiveMatmulProgramInfo = (
inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank'],
},
getRunData: () => ({
- outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
+ outputs: [
+   {
+     dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape,
+     dataType: inputs[0].dataType,
+   },
+ ],
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
programUniforms,
}),
6 changes: 3 additions & 3 deletions js/web/lib/wasm/jsep/webgpu/ops/transpose.ts
@@ -83,14 +83,14 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
return {
name: 'Transpose',
shaderCache: { hint: `${permAttr}`, inputDependencies: ['rank'] },
- getRunData: (inputs) => {
+ getRunData: () => {
const outputSize = ShapeUtil.size(outputShape);
return {
- outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
+ outputs: [{ dims: outputShape, dataType: inputTensor.dataType }],
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
programUniforms: [
{ type: DataType.uint32, data: outputSize },
- ...createTensorShapeVariables(inputs[0].dims, outputShape),
+ ...createTensorShapeVariables(inputTensor.dims, outputShape),
],
};
},
69 changes: 69 additions & 0 deletions js/web/test/data/ops/conv1d.jsonc
@@ -0,0 +1,69 @@
[
{
"name": "conv 1D without bias addition A",
"operator": "Conv",
"inputShapeDefinitions": "rankOnly",
"opset": { "domain": "", "version": 17 },
"attributes": [{ "name": "kernel_shape", "data": [2], "type": "ints" }],
"cases": [
{
"name": "T[0]",
"inputs": [
{
"data": [10, 20, 30],
"dims": [1, 1, 3],
"type": "float32"
},
{
"data": [1, 2],
"dims": [1, 1, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [50, 80],
"dims": [1, 1, 2],
"type": "float32"
}
]
}
]
},
{
"name": "conv 1D with bias addition A",
"operator": "Conv",
"inputShapeDefinitions": "rankOnly",
"opset": { "domain": "", "version": 17 },
"attributes": [{ "name": "kernel_shape", "data": [2], "type": "ints" }],
"cases": [
{
"name": "T[0]",
"inputs": [
{
"data": [10, 20, 30, 40],
"dims": [1, 2, 2],
"type": "float32"
},
{
"data": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
"dims": [4, 2, 2],
"type": "float32"
},
{
"data": [0.1, 0.2, 0.3, 0.4],
"dims": [4],
"type": "float32"
}
],
"outputs": [
{
"data": [100.1, 100.2, 100.3, 100.4],
"dims": [1, 4, 1],
"type": "float32"
}
]
}
]
}
]
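
As a sanity check, the expected outputs in the test cases above follow from a plain sliding dot product. A minimal illustrative sketch (not part of the test harness):

```typescript
// Naive 1-D convolution (single channel, stride 1, no padding), enough to
// reproduce the first test case: input [10, 20, 30], kernel [1, 2].
const conv1dNaive = (x: number[], w: number[]): number[] => {
  const out: number[] = [];
  for (let i = 0; i + w.length <= x.length; i++) {
    let acc = 0;
    for (let j = 0; j < w.length; j++) {
      acc += x[i + j] * w[j]; // sliding dot product
    }
    out.push(acc);
  }
  return out;
};

console.log(conv1dNaive([10, 20, 30], [1, 2])); // [50, 80], matching T[0]
```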
1 change: 1 addition & 0 deletions js/web/test/suite-test-list.jsonc
@@ -1347,6 +1347,7 @@
"concat_zero-sized.jsonc",
"cast.jsonc",
"conv.jsonc",
"conv1d.jsonc",
"conv3dncdhw.jsonc",
"cos.jsonc",
"div.jsonc",
