From 5353adcde37a118bdd25882482fd584c5ed3f343 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 5 Dec 2023 05:18:37 +0800 Subject: [PATCH] [js/webgpu] Use the naive convTranspose when in/out channels are both 1 (#18658) ### Description With this change, convTranspose with input0 [1, 18, 32, 1], input1 [1, 1, 16, 16] becomes 0.59ms from 6.64ms. --- js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index e880afe09a5d8..32b1d52ed94ca 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -209,18 +209,20 @@ const convTranspose2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => { const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs); const isChannelsLast = attributes.format === 'NHWC'; - const hasBias = inputs.length === 3; - if (adjustedAttributes.group !== 1) { + const outputShape = adjustedAttributes.outputShape; + const outChannels = outputShape[isChannelsLast ? 3 : 1]; + const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; + // Switch to naive method when outChannels and inputChannels are very small. It's because that in this case it's + // not suitable for matmul version since matmul uses tile size 32x32 resulting the underlying execution unit + // utilization rate is very low. + if (adjustedAttributes.group !== 1 || (outChannels === 1 && inputChannels === 1)) { context.compute(createConvTranspose2DProgramInfo(inputs, adjustedAttributes)); return; } - const outputShape = adjustedAttributes.outputShape; const outHeight = outputShape[isChannelsLast ? 1 : 2]; const outWidth = outputShape[isChannelsLast ? 2 : 3]; - const outChannels = outputShape[isChannelsLast ? 3 : 1]; const weightHeight = inputs[1].dims[2]; const weightWidth = inputs[1].dims[3]; - const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels; const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth; @@ -240,6 +242,7 @@ const convTranspose2d = // STEP.2: prepare reshaped inputs const convTransposeInputs = [inputs[0], transposedWeight]; + const hasBias = inputs.length === 3; if (hasBias) { if (!isChannelsLast && inputs[2].dims.length === 1) { convTransposeInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1]));