From 8143b001fd6ea0f0412cbfc90fdb3161b25d6dfe Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Wed, 13 Nov 2024 04:37:07 +0800 Subject: [PATCH] [js/webgpu] Optimize ConvTranspose (#22774) BUG #22031 The overall time of ConvTranspose in Demucs model becomes 517.41 ms from 1415.65 ms on my iGPUs. --- .../ops/3rd-party/conv_backprop_webgpu.ts | 313 +++++------------- .../wasm/jsep/webgpu/ops/conv-transpose.ts | 69 +--- 2 files changed, 95 insertions(+), 287 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index 2a8756e435b8e..cb1f30ecdd1f4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -29,229 +29,27 @@ import { ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType, + getMaxComponents, } from '../common'; import { ConvTransposeAttributes } from '../conv-transpose'; -const createConvTranspose2DOpProgramShaderSource = ( - shaderHelper: ShaderHelper, - inputs: readonly TensorView[], - outputShape: readonly number[], - hasBias: boolean, - is1DimensionDispatch: boolean, - isVec4 = false, - dataType: string, - uniforms: UniformsArrayType, - isChannelsLast = false, -): string => { - const rowDim = isChannelsLast ? 1 : 2; - const colDim = isChannelsLast ? 2 : 3; - const channelDim = isChannelsLast ? 3 : 1; - const workPerThread = isVec4 ? 2 : 1; - - let declareFunctions = ` - fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? `vec4<${dataType}>` : dataType}) { - result[flatIndex] = ${isVec4 ? `vec4<${dataType}>` : dataType}(value); - }`; - if (hasBias) { - declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${dataType}>` : dataType} { - return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; - }`; - } - const components = isVec4 ? 4 : 1; - const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components); - const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims.length, components); - const inputVariables = [dy, w]; - if (hasBias) { - inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]].length, components)); - } - const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); - - const codeSnippet4 = `{ - let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / uniforms.result_shape[1]; - let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % uniforms.result_shape[1]; - let c = ${is1DimensionDispatch ? 'global_id.y' : 'workgroup_id.y'} * ${workPerThread}; - let d1: u32 = ${is1DimensionDispatch ? 'global_id.x' : 'workgroup_id.x'} * 4; - - let dyCorner = vec2(i32(r), i32(c)) - vec2(uniforms.pads); - - // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). - // ? = to be determined. : = across all values in that axis. - var dotProd: array, ${workPerThread}>; - for (var i = 0; i < ${workPerThread}; i++) { - dotProd[i] = vec4<${dataType}>(0.0); - } - for (var wR: u32 = 0; wR < uniforms.filter_dims[0]; wR = wR + 1) { - var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(uniforms.strides.x); - let wRPerm = uniforms.filter_dims[0] - 1 - wR; - if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[1]) || - fract(dyR) > 0.0 || wRPerm < 0) { - continue; - } - let idyR: u32 = u32(dyR); - - for (var wC: u32 = 0; wC < uniforms.filter_dims[1]; wC = wC + 1) { - let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); - let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); - let wCPerm = uniforms.filter_dims[1] - 1 - wC; - if (wCPerm < 0) { - continue; - } - var bDyCVal = true; - var bDyCVal2 = true; - if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[2]) || - fract(dyC) > 0.0) { - bDyCVal = false; - } - if (dyC2 < 0.0 || dyC2 >= ${dataType}(uniforms.Dy_shape[2]) || - fract(dyC2) > 0.0) { - bDyCVal2 = false; - } - - let idyC: u32 = u32(dyC); - let idyC2: u32 = u32(dyC2); - if (bDyCVal && bDyCVal2) { - let d2Length = uniforms.Dy_shape[3]; - for (var d2 :u32 = 0; d2 < d2Length; d2 = d2 + 4) { - let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; - let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; - let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')}; - let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; - - var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')}; - let tmpval = vec4<${dataType}>(dot(xValue, wValue0), - dot(xValue, wValue1), - dot(xValue, wValue2), - dot(xValue, wValue3)); - dotProd[0] = dotProd[0] + tmpval; - - xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')}; - - dotProd[1] = dotProd[1] + vec4<${dataType}>(dot(xValue, wValue0), - dot(xValue, wValue1), - dot(xValue, wValue2), - dot(xValue, wValue3)); - } - } else if (bDyCVal) { - let d2Length = uniforms.Dy_shape[${channelDim}]; - for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) { - let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; - let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; - let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')}; - let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; - - var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')}; - let tmpval = vec4<${dataType}>(dot(xValue, wValue0), - dot(xValue, wValue1), - dot(xValue, wValue2), - dot(xValue, wValue3)); - dotProd[0] = dotProd[0] + tmpval; - } - } else if (bDyCVal2) { - let d2Length = uniforms.Dy_shape[3]; - for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) { - let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; - let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; - let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')}; - let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; - - var xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')}; - let tmpval = vec4<${dataType}>(dot(xValue, wValue0), - dot(xValue, wValue1), - dot(xValue, wValue2), - dot(xValue, wValue3)); - dotProd[1] = dotProd[1] + tmpval; - } - } - } - } - - for (var i: u32 = 0; i < ${workPerThread}; i = i + 1) { - let value = dotProd[i] + ${hasBias ? 'bias[c+i]' : `vec4<${dataType}>(0.0)`}; - ${output.set('batch', 'r', 'c + i', 'd1', 'value')}; - } - }`; - const codeSnippet = ` - let outputIndices = ${output.offsetToIndices('global_idx')}; - let batch = ${output.indicesGet('outputIndices', 0)}; - let d1 = ${output.indicesGet('outputIndices', channelDim)}; - let r = ${output.indicesGet('outputIndices', rowDim)}; - let c = ${output.indicesGet('outputIndices', colDim)}; - let dyCorner = vec2(i32(r), i32(c)) - uniforms.pads; - let dyRCorner = dyCorner.x; - let dyCCorner = dyCorner.y; - let groupId = d1 / uniforms.output_channels_per_group; - let wOutChannel = d1 - groupId * uniforms.output_channels_per_group; - // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). - // ? = to be determined. : = across all values in that axis. - var dotProd = ${dataType}(0.0); - for (var wR: u32 = 0; wR < uniforms.effective_filter_dims.x; wR = wR + 1) { - if (wR % uniforms.dilations.x != 0) { - continue; - } - let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(uniforms.strides[0]); - let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x; - if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[${rowDim}]) || fract(dyR) > 0.0 || - wRPerm < 0) { - continue; - } - let idyR: u32 = u32(dyR); - - for (var wC: u32 = 0; wC < uniforms.effective_filter_dims.y; wC = wC + 1) { - if (wC % uniforms.dilations.y != 0) { - continue; - } - let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); - let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y; - if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[${colDim}]) || - fract(dyC) > 0.0 || wCPerm < 0) { - continue; - } - let idyC: u32 = u32(dyC); - var inputChannel = groupId * uniforms.input_channels_per_group; - for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + 1) { - let xValue = ${ - isChannelsLast - ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') - : dy.get('batch', 'inputChannel', 'idyR', 'idyC') - }; - let wValue = ${w.get('inputChannel', 'wOutChannel', 'u32(wRPerm)', 'u32(wCPerm)')}; - dotProd = dotProd + xValue * wValue; - inputChannel = inputChannel + 1; - } - } - } - let value = dotProd + ${hasBias ? 'bias[d1]' : `${dataType}(0.0)`}; - ${output.setByOffset('global_idx', 'value')}; - `; - - return ` - ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} - ${declareFunctions} - - ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}; - ${isVec4 ? codeSnippet4 : codeSnippet}}`; -}; - export const createConvTranspose2DProgramInfo = ( inputs: readonly TensorView[], attributes: ConvTransposeAttributes, squeezeOutputShapeFunction?: (shape: readonly number[]) => number[], ): ProgramInfo => { const hasBias = inputs.length > 2; - // const isChannelsLast = attributes.format === 'NHWC'; const outputShape = attributes.outputShape; - const outputSize = ShapeUtil.size(outputShape); - - // const inChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; - // TODO Enable isVec4 for performance - // Disabled due to weight matrix layout issue - // const isVec4 = attributes.group === 1 && isChannelsLast && inChannels % 4 === 0 && outChannels % 4 === 0; + const isChannelsLast = attributes.format === 'NHWC'; + const group = attributes.group; + const wShape = inputs[1].dims; + const inputChannelsPerGroup = wShape[2] / group; + const outputChannelsPerGroup = wShape[3]; + const components = isChannelsLast ? getMaxComponents(outputChannelsPerGroup) : 1; + const outputSize = ShapeUtil.size(outputShape) / components; const dispatch = [Math.ceil(outputSize / 64), 1, 1]; LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`); - const isChannelsLast = attributes.format === 'NHWC'; const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; const strides = [attributes.strides[0], attributes.strides[1]]; const filterDims = [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]]; @@ -268,15 +66,9 @@ export const createConvTranspose2DProgramInfo = ( ]; const pads = [ effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2), - effectiveFilterDims[1] - 1 - Math.floor(attributes.pads[1] + attributes.pads[3]) / 2, + effectiveFilterDims[1] - 1 - Math.floor((attributes.pads[1] + attributes.pads[3]) / 2), ]; - const isVec4 = false; - const group = attributes.group; - const wShape = inputs[1].dims; - const inputChannelsPerGroup = wShape[0] / group; - const outputChannelsPerGroup = wShape[1]; - const programUniforms: ProgramUniform[] = [ { type: DataType.uint32, data: outputSize }, { type: DataType.uint32, data: strides }, @@ -294,7 +86,6 @@ export const createConvTranspose2DProgramInfo = ( } programUniforms.push(...createTensorShapeVariables(outputShape)); - const is1DimensionDispatch = dispatch[1] === 1 && dispatch[2] === 1; const getShaderSource = (shaderHelper: ShaderHelper) => { const uniforms: UniformsArrayType = [ { name: 'output_size', type: 'u32' }, @@ -307,21 +98,83 @@ export const createConvTranspose2DProgramInfo = ( { name: 'output_channels_per_group', type: 'u32' }, ]; const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - return `${createConvTranspose2DOpProgramShaderSource( - shaderHelper, - inputs, - outputShape, - hasBias, - is1DimensionDispatch, - isVec4, - dataType, - uniforms, - isChannelsLast, - )}`; + const rowDim = isChannelsLast ? 1 : 2; + const colDim = isChannelsLast ? 2 : 3; + const channelDim = isChannelsLast ? 3 : 1; + + const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components); + const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims.length); + const inputVariables = [dy, w]; + if (hasBias) { + inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]].length, components)); + } + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + + const codeSnippet = ` + let outputIndices = ${output.offsetToIndices(`global_idx * ${components}`)}; + let batch = ${output.indicesGet('outputIndices', 0)}; + let d1 = ${output.indicesGet('outputIndices', channelDim)}; + let r = ${output.indicesGet('outputIndices', rowDim)}; + let c = ${output.indicesGet('outputIndices', colDim)}; + let dyCorner = vec2(i32(r), i32(c)) - uniforms.pads; + let dyRCorner = dyCorner.x; + let dyCCorner = dyCorner.y; + let groupId = d1 / uniforms.output_channels_per_group; + let wOutChannel = d1 - groupId * uniforms.output_channels_per_group; + // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). + // ? = to be determined. : = across all values in that axis. + var dotProd = ${output.type.value}(0.0); + for (var wR: u32 = 0; wR < uniforms.effective_filter_dims.x; wR = wR + 1) { + if (wR % uniforms.dilations.x != 0) { + continue; + } + let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(uniforms.strides[0]); + let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x; + if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[${rowDim}]) || fract(dyR) > 0.0 || + wRPerm < 0) { + continue; + } + let idyR: u32 = u32(dyR); + + for (var wC: u32 = 0; wC < uniforms.effective_filter_dims.y; wC = wC + 1) { + if (wC % uniforms.dilations.y != 0) { + continue; + } + let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); + let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y; + if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[${colDim}]) || + fract(dyC) > 0.0 || wCPerm < 0) { + continue; + } + let idyC: u32 = u32(dyC); + var inputChannel = groupId * uniforms.input_channels_per_group; + for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + 1) { + let xValue = ${ + isChannelsLast + ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') + : dy.get('batch', 'inputChannel', 'idyR', 'idyC') + }; + let w_offset = ${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)}; + let wValue = ${w.getByOffset(`w_offset / ${components}`)}; + dotProd = dotProd + xValue * wValue; + inputChannel = inputChannel + 1; + } + } + } + let value = dotProd${hasBias ? ` + bias[d1 / ${components}]` : ''}; + ${output.setByOffset('global_idx', 'value')}; + `; + + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}; + ${codeSnippet}}`; }; + return { name: 'ConvTranspose2D', - shaderCache: { hint: `${attributes.cacheKey};`, inputDependencies }, + shaderCache: { hint: `${attributes.cacheKey};${components}`, inputDependencies }, getRunData: () => ({ dispatchGroup: { x: dispatch[0], y: dispatch[1], z: dispatch[2] }, outputs: [ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index 236f1b09a6c93..3e168ddedac86 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -4,7 +4,6 @@ import { TensorView } from '../../tensor-view'; import { ComputeContext } from '../types'; -import { createConv2DTransposeMatMulProgramInfo } from './3rd-party/conv_backprop_mm_webgpu'; import { createConvTranspose2DProgramInfo } from './3rd-party/conv_backprop_webgpu'; import { ConvAttributes } from './conv'; import { parseInternalActivationAttributes } from './fuse-utils'; @@ -227,41 +226,16 @@ const validateInputs = (inputs: readonly TensorView[], attributes: ConvTranspose } }; -// for transposing weight tensor from [C, M/group, KH, KW] to [KH, KW, M/group, C] -const weightTransposePerm = [2, 3, 1, 0]; - const convTranspose2d = ( context: ComputeContext, inputs: readonly TensorView[], attributes: ConvTransposeAttributes, + squeezeOutputShapeFunction?: (shape: readonly number[]) => number[], ): void => { - const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs); - const isChannelsLast = attributes.format === 'NHWC'; - const outputShape = adjustedAttributes.outputShape; - const outChannels = outputShape[isChannelsLast ? 3 : 1]; - const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; - // Switch to naive method when outChannels and inputChannels are very small. It's because that in this case it's - // not suitable for matmul version since matmul uses tile size 32x32 resulting the underlying execution unit - // utilization rate is very low. - if (adjustedAttributes.group !== 1 || (outChannels === 1 && inputChannels === 1)) { - context.compute(createConvTranspose2DProgramInfo(inputs, adjustedAttributes)); - return; - } - const outHeight = outputShape[isChannelsLast ? 1 : 2]; - const outWidth = outputShape[isChannelsLast ? 2 : 3]; - const weightHeight = inputs[1].dims[2]; - const weightWidth = inputs[1].dims[3]; - - const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels; - const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth; - const dimInner = weightHeight * weightWidth * inputChannels; - - const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true; - // STEP.1: transpose weight const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? - context.compute(createTransposeProgramInfo(inputs[1], weightTransposePerm), { + context.compute(createTransposeProgramInfo(inputs[1], [2, 3, 0, 1]), { inputs: [1], outputs: [attributes.wIsConst ? -2 : -1], })[0]; @@ -271,29 +245,12 @@ const convTranspose2d = ( // STEP.2: prepare reshaped inputs const convTransposeInputs = [inputs[0], transposedWeight]; - const hasBias = inputs.length === 3; - if (hasBias) { - if (!isChannelsLast && inputs[2].dims.length === 1) { - convTransposeInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1])); - } else { - convTransposeInputs.push(inputs[2]); - } + if (inputs.length === 3) { + convTransposeInputs.push(inputs[2]); } - - // STEP.3: compute matmul - context.compute( - createConv2DTransposeMatMulProgramInfo( - convTransposeInputs, - adjustedAttributes, - outputShape, - dimAOuter, - dimBOuter, - dimInner, - hasBias, - sequentialAccessByThreads, - ), - { inputs: convTransposeInputs }, - ); + context.compute(createConvTranspose2DProgramInfo(convTransposeInputs, attributes, squeezeOutputShapeFunction), { + inputs: convTransposeInputs, + }); }; const convTranspose1d = (context: ComputeContext, attributes: ConvTransposeAttributes): void => { @@ -338,12 +295,9 @@ const convTranspose1d = (context: ComputeContext, attributes: ConvTransposeAttri { ...attributes, pads, strides, dilations, kernelShape }, inputs, ); - context.compute( - createConvTranspose2DProgramInfo(inputs, adjustedAttributes, (outputShape) => - isChannelLast - ? [outputShape[0], outputShape[2], outputShape[3]] - : [outputShape[0], outputShape[1], outputShape[3]], - ), + + convTranspose2d(context, inputs, adjustedAttributes, (outputShape) => + isChannelLast ? [outputShape[0], outputShape[2], outputShape[3]] : [outputShape[0], outputShape[1], outputShape[3]], ); }; @@ -352,6 +306,7 @@ export const convTranspose = (context: ComputeContext, attributes: ConvTranspose if (context.inputs[0].dims.length === 3) { convTranspose1d(context, attributes); } else { - convTranspose2d(context, context.inputs, attributes); + const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, context.inputs); + convTranspose2d(context, context.inputs, adjustedAttributes); } };