From 27a6890529b7bfc379d5bc655632ec320e309ff7 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Fri, 23 Aug 2024 13:56:07 +0800 Subject: [PATCH] [js/webgpu] Optimize conv1d by conv2d (#19388) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Optimize conv1d to go to the conv2d path to utilize the conv2d's optimization path. See whisper-tiny-encoder model becomes 158.66 ms from 532.28 ms. Conv goes to Conv2DMatMul(8 ms) instead of GroupedConv(382 ms). Old profiling result: Kernel | Time (ms) | Percentage (%) -- | -- | -- Conv\|GroupedConv | 382.99 | 71.95 MatMul | 126.16 | 23.70 Softmax | 7.01 | 1.32 Transpose | 4.59 | 0.86 Add | 4.39 | 0.82 Mul | 2.36 | 0.44 Div | 1.44 | 0.27 ReduceMean\|ReduceMeanShared | 1.25 | 0.23 Erf | 0.85 | 0.16 Sub | 0.72 | 0.14 Pow | 0.46 | 0.09 Sqrt | 0.07 | 0.01 Sum | 532.28 |   New profiling result with this PR: Kernel | Time (ms) | Percentage (%) -- | -- | -- MatMul | 127.07 | 80.09 Conv\|Conv2DMatMul | 8.00 | 5.04 Softmax | 6.95 | 4.38 Transpose | 4.65 | 2.93 Add | 4.26 | 2.68 Mul | 2.56 | 1.61 Div | 1.51 | 0.95 ReduceMean\|ReduceMeanShared | 1.31 | 0.83 Erf | 0.85 | 0.54 Sub | 0.79 | 0.50 Pow | 0.46 | 0.29 Conv\|Transpose | 0.26 | 0.17 Sqrt | 0.00 | 0.00 Sum | 158.66 |   --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 12 ++-- .../ops/3rd-party/matmul_packed_webgpu.ts | 22 +++--- .../lib/wasm/jsep/webgpu/ops/conv-grouped.ts | 8 ++- js/web/lib/wasm/jsep/webgpu/ops/conv.ts | 54 ++++++++++----- js/web/lib/wasm/jsep/webgpu/ops/matmul.ts | 12 ++-- js/web/lib/wasm/jsep/webgpu/ops/transpose.ts | 6 +- js/web/test/data/ops/conv1d.jsonc | 69 +++++++++++++++++++ js/web/test/suite-test-list.jsonc | 1 + 8 files changed, 141 insertions(+), 43 deletions(-) create mode 100644 js/web/test/data/ops/conv1d.jsonc diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 7884a3cd1a684..3ef5c943d5624 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -182,6 +182,7 @@ export const createConv2DMatMulProgramInfo = ( dimInner: number, hasBias: boolean, sequentialAccessByThreads: boolean, + squeezeOutputShapeFunction?: (shape: readonly number[]) => number[], ): ProgramInfo => { const isChannelsLast = attributes.format === 'NHWC'; const inChannels = isChannelsLast ? inputs[0].dims[3] : inputs[0].dims[1]; @@ -309,13 +310,16 @@ export const createConv2DMatMulProgramInfo = ( return { name: 'Conv2DMatMul', shaderCache: { - hint: `${attributes.cacheKey};${innerElementSize};${isVec4};${fitAOuter};${fitBOuter};${fitInner};${ - tileAOuter - };${tileBOuter};${tileInner}`, + hint: `${attributes.cacheKey};${innerElementSize};${isVec4};${fitAOuter};${fitBOuter};${fitInner};${tileAOuter};${tileBOuter};${tileInner}`, inputDependencies, }, getRunData: () => ({ - outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + outputs: [ + { + dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, + dataType: inputs[0].dataType, + }, + ], dispatchGroup: { x: dispatch[0], y: dispatch[1], z: dispatch[2] }, programUniforms, }), diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index f9bc015055c9f..f0287529ca08b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -110,13 +110,9 @@ export const makeMatMulPackedVec4Source = ( workPerThread[0] === 4 ) ) { - throw new Error(`If transposeA ${transposeA} is true, innerElementSize ${ - innerElementSize - } and workPerThread[1] ${workPerThread[1]} must be 4. + throw new Error(`If transposeA ${transposeA} is true, innerElementSize ${innerElementSize} and workPerThread[1] ${workPerThread[1]} must be 4. Otherwise, innerElementSize ${innerElementSize} must be 3 or 4. - tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${workgroupSize[0]}. tileInner ${ - tileInner - } must be divisible by workgroupSize[1] ${workgroupSize[1]}. colPerThread ${workPerThread[0]} must be 4.`); + tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${workgroupSize[0]}. tileInner ${tileInner} must be divisible by workgroupSize[1] ${workgroupSize[1]}. colPerThread ${workPerThread[0]} must be 4.`); } return ` var mm_Asub: array, ${tileAWidth / innerElementSize}>, ${tileAHight}>; @@ -227,11 +223,7 @@ export const makeMatMulPackedSource = ( !(tileAHight % workgroupSize[1] === 0 && tileAWidth % workgroupSize[0] === 0 && tileInner % workgroupSize[1] === 0) ) { throw new Error( - `tileAHight ${tileAHight} must be divisible by workgroupSize[1]${ - workgroupSize[1] - }, tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${ - workgroupSize[0] - }, tileInner ${tileInner} must be divisible by workgroupSize[1]${workgroupSize[1]}`, + `tileAHight ${tileAHight} must be divisible by workgroupSize[1]${workgroupSize[1]}, tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${workgroupSize[0]}, tileInner ${tileInner} must be divisible by workgroupSize[1]${workgroupSize[1]}`, ); } const rowPerThreadA = tileAHight / workgroupSize[1]; @@ -470,6 +462,7 @@ export const createMatmulProgramInfo = ( outputShape: readonly number[], reshapedOutputShape?: readonly number[], isChannelsLast = false /* only used for conv2dByMatMul*/, + squeezeOutputShapeFunction?: (shape: readonly number[]) => number[], ): ProgramInfo => { const aShape = inputs[0].dims; const bShape = inputs[1].dims; @@ -562,7 +555,12 @@ export const createMatmulProgramInfo = ( inputDependencies, }, getRunData: () => ({ - outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + outputs: [ + { + dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, + dataType: inputs[0].dataType, + }, + ], dispatchGroup: { x: dispatch[0], y: dispatch[1], z: dispatch[2] }, programUniforms, }), diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index dbe0e0c9647bd..1ad4149b01e08 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -145,6 +145,7 @@ export const createGroupedConvVectorizeProgramInfo = ( inputs: readonly TensorView[], attributes: ConvAttributes, outputShape: readonly number[], + squeezeOutputShapeFunction?: (shape: readonly number[]) => number[], ): ProgramInfo => { const hasBias = inputs.length > 2; const components = getMaxComponents(outputShape[3]); @@ -234,7 +235,12 @@ export const createGroupedConvVectorizeProgramInfo = ( inputDependencies: hasBias ? ['rank', 'rank', 'type'] : ['rank', 'rank'], }, getRunData: () => ({ - outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + outputs: [ + { + dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, + dataType: inputs[0].dataType, + }, + ], dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, programUniforms, }), diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index 2a32b566ba4ba..241aae8c46603 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -152,9 +152,12 @@ export const parseConvAttributes = (attributes: Record): ConvAt }; }; -const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvAttributes): void => { - const adjustedAttributes = getAdjustedConvAttributes(attributes, inputs); - +const conv2d = ( + context: ComputeContext, + inputs: readonly TensorView[], + attributes: ConvAttributes, + squeezeOutputShapeFunction?: (shape: readonly number[]) => number[], +): void => { // check attributes // const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */ @@ -177,7 +180,7 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut inputs[0].dims, inputs[1].dims, attributes.dilations, - adjustedAttributes.pads, + attributes.pads, attributes.strides, isChannelsLast, ); @@ -194,11 +197,12 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut if (inputs.length === 3) { convInputs.push(inputs[2]); } - context.compute(createGroupedConvVectorizeProgramInfo(convInputs, adjustedAttributes, outputShape), { - inputs: convInputs, - }); + context.compute( + createGroupedConvVectorizeProgramInfo(convInputs, attributes, outputShape, squeezeOutputShapeFunction), + { inputs: convInputs }, + ); } else { - context.compute(createGroupedConvProgramInfo(inputs, adjustedAttributes)); + context.compute(createGroupedConvProgramInfo(inputs, attributes, squeezeOutputShapeFunction)); } return; } @@ -214,7 +218,7 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut inputs[0].dims, inputs[1].dims, attributes.dilations, - adjustedAttributes.pads, + attributes.pads, attributes.strides, isChannelsLast, ); @@ -280,12 +284,26 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut // Tune the threshold. if (N < 8 && K < 8) { context.compute( - createNaiveMatmulProgramInfo(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast), + createNaiveMatmulProgramInfo( + matmulInputs, + attributes, + outputShape, + matmulOutputShape, + isChannelsLast, + squeezeOutputShapeFunction, + ), { inputs: matmulInputs }, ); } else { context.compute( - createMatmulProgramInfo(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast), + createMatmulProgramInfo( + matmulInputs, + attributes, + outputShape, + matmulOutputShape, + isChannelsLast, + squeezeOutputShapeFunction, + ), { inputs: matmulInputs }, ); } @@ -320,13 +338,14 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut context.compute( createConv2DMatMulProgramInfo( convInputs, - adjustedAttributes, + attributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias, sequentialAccessByThreads, + squeezeOutputShapeFunction, ), { inputs: convInputs }, ); @@ -357,12 +376,8 @@ const conv1d = (context: ComputeContext, attributes: ConvAttributes): void => { { ...attributes, pads, strides, dilations, kernelShape }, inputs, ); - context.compute( - createGroupedConvProgramInfo(inputs, adjustedAttributes, (outputShape) => - isChannelLast - ? [outputShape[0], outputShape[2], outputShape[3]] - : [outputShape[0], outputShape[1], outputShape[3]], - ), + conv2d(context, inputs, adjustedAttributes, (outputShape) => + isChannelLast ? [outputShape[0], outputShape[2], outputShape[3]] : [outputShape[0], outputShape[1], outputShape[3]], ); }; @@ -398,6 +413,7 @@ export const conv = (context: ComputeContext, attributes: ConvAttributes): void } else if (context.inputs[0].dims.length === 5) { conv3d(context, context.inputs, attributes); } else { - conv2d(context, context.inputs, attributes); + const adjustedAttributes = getAdjustedConvAttributes(attributes, context.inputs); + conv2d(context, context.inputs, adjustedAttributes); } }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index d2a6b2d352e25..7605e67c972b9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -32,6 +32,7 @@ export const createNaiveMatmulProgramInfo = ( outputShape: readonly number[], reshapedOutputShape?: readonly number[], isChannelsLast = false /* only used for conv2dByMatMul*/, + squeezeOutputShapeFunction?: (shape: readonly number[]) => number[], ): ProgramInfo => { const aShape = inputs[0].dims; const bShape = inputs[1].dims; @@ -120,9 +121,7 @@ export const createNaiveMatmulProgramInfo = ( for (let j = 0; j < aComponents; j++) { calcStr += ` - values[${i}] = fma(${b.type.value}(a_data${aComponents === 1 ? '' : `[${j}]`}), b_data${j}, values[${ - i - }]);\n`; + values[${i}] = fma(${b.type.value}(a_data${aComponents === 1 ? '' : `[${j}]`}), b_data${j}, values[${i}]);\n`; } } return calcStr; @@ -168,7 +167,12 @@ export const createNaiveMatmulProgramInfo = ( inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank'], }, getRunData: () => ({ - outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + outputs: [ + { + dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, + dataType: inputs[0].dataType, + }, + ], dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, programUniforms, }), diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index 4c1131477cd0f..3c08580128e04 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -83,14 +83,14 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu return { name: 'Transpose', shaderCache: { hint: `${permAttr}`, inputDependencies: ['rank'] }, - getRunData: (inputs) => { + getRunData: () => { const outputSize = ShapeUtil.size(outputShape); return { - outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + outputs: [{ dims: outputShape, dataType: inputTensor.dataType }], dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, programUniforms: [ { type: DataType.uint32, data: outputSize }, - ...createTensorShapeVariables(inputs[0].dims, outputShape), + ...createTensorShapeVariables(inputTensor.dims, outputShape), ], }; }, diff --git a/js/web/test/data/ops/conv1d.jsonc b/js/web/test/data/ops/conv1d.jsonc new file mode 100644 index 0000000000000..a387f0de324a6 --- /dev/null +++ b/js/web/test/data/ops/conv1d.jsonc @@ -0,0 +1,69 @@ +[ + { + "name": "conv 1D without bias addition A", + "operator": "Conv", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [{ "name": "kernel_shape", "data": [2], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30], + "dims": [1, 1, 3], + "type": "float32" + }, + { + "data": [1, 2], + "dims": [1, 1, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [50, 80], + "dims": [1, 1, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "conv 1D with bias addition A", + "operator": "Conv", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [{ "name": "kernel_shape", "data": [2], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40], + "dims": [1, 2, 2], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "dims": [4, 2, 2], + "type": "float32" + }, + { + "data": [0.1, 0.2, 0.3, 0.4], + "dims": [4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [100.1, 100.2, 100.3, 100.4], + "dims": [1, 4, 1], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 44b89142790ab..edbaeb6f4095c 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1347,6 +1347,7 @@ "concat_zero-sized.jsonc", "cast.jsonc", "conv.jsonc", + "conv1d.jsonc", "conv3dncdhw.jsonc", "cos.jsonc", "div.jsonc",