Skip to content

Commit

Permalink
[js/webgpu] Optimize ConvTranspose (#22774)
Browse files Browse the repository at this point in the history
BUG #22031 

The overall time of ConvTranspose in Demucs model becomes 517.41 ms from
1415.65 ms on my iGPUs.
  • Loading branch information
qjia7 authored and guschmue committed Dec 2, 2024
1 parent 09b1531 commit 8143b00
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 287 deletions.
313 changes: 83 additions & 230 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,229 +29,27 @@ import {
ShaderHelper,
tensorTypeToWsglStorageType,
UniformsArrayType,
getMaxComponents,
} from '../common';
import { ConvTransposeAttributes } from '../conv-transpose';

const createConvTranspose2DOpProgramShaderSource = (
shaderHelper: ShaderHelper,
inputs: readonly TensorView[],
outputShape: readonly number[],
hasBias: boolean,
is1DimensionDispatch: boolean,
isVec4 = false,
dataType: string,
uniforms: UniformsArrayType,
isChannelsLast = false,
): string => {
const rowDim = isChannelsLast ? 1 : 2;
const colDim = isChannelsLast ? 2 : 3;
const channelDim = isChannelsLast ? 3 : 1;
const workPerThread = isVec4 ? 2 : 1;

let declareFunctions = `
fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? `vec4<${dataType}>` : dataType}) {
result[flatIndex] = ${isVec4 ? `vec4<${dataType}>` : dataType}(value);
}`;
if (hasBias) {
declareFunctions += `
fn getBiasByOutputCoords(coords : vec4<u32>) -> ${isVec4 ? `vec4<${dataType}>` : dataType} {
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
}`;
}
const components = isVec4 ? 4 : 1;
const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components);
const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims.length, components);
const inputVariables = [dy, w];
if (hasBias) {
inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]].length, components));
}
const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);

const codeSnippet4 = `{
let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / uniforms.result_shape[1];
let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % uniforms.result_shape[1];
let c = ${is1DimensionDispatch ? 'global_id.y' : 'workgroup_id.y'} * ${workPerThread};
let d1: u32 = ${is1DimensionDispatch ? 'global_id.x' : 'workgroup_id.x'} * 4;
let dyCorner = vec2<i32>(i32(r), i32(c)) - vec2<i32>(uniforms.pads);
// Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
// ? = to be determined. : = across all values in that axis.
var dotProd: array<vec4<${dataType}>, ${workPerThread}>;
for (var i = 0; i < ${workPerThread}; i++) {
dotProd[i] = vec4<${dataType}>(0.0);
}
for (var wR: u32 = 0; wR < uniforms.filter_dims[0]; wR = wR + 1) {
var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(uniforms.strides.x);
let wRPerm = uniforms.filter_dims[0] - 1 - wR;
if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[1]) ||
fract(dyR) > 0.0 || wRPerm < 0) {
continue;
}
let idyR: u32 = u32(dyR);
for (var wC: u32 = 0; wC < uniforms.filter_dims[1]; wC = wC + 1) {
let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y);
let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(uniforms.strides.y);
let wCPerm = uniforms.filter_dims[1] - 1 - wC;
if (wCPerm < 0) {
continue;
}
var bDyCVal = true;
var bDyCVal2 = true;
if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[2]) ||
fract(dyC) > 0.0) {
bDyCVal = false;
}
if (dyC2 < 0.0 || dyC2 >= ${dataType}(uniforms.Dy_shape[2]) ||
fract(dyC2) > 0.0) {
bDyCVal2 = false;
}
let idyC: u32 = u32(dyC);
let idyC2: u32 = u32(dyC2);
if (bDyCVal && bDyCVal2) {
let d2Length = uniforms.Dy_shape[3];
for (var d2 :u32 = 0; d2 < d2Length; d2 = d2 + 4) {
let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')};
let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')};
let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')};
let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')};
var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')};
let tmpval = vec4<${dataType}>(dot(xValue, wValue0),
dot(xValue, wValue1),
dot(xValue, wValue2),
dot(xValue, wValue3));
dotProd[0] = dotProd[0] + tmpval;
xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')};
dotProd[1] = dotProd[1] + vec4<${dataType}>(dot(xValue, wValue0),
dot(xValue, wValue1),
dot(xValue, wValue2),
dot(xValue, wValue3));
}
} else if (bDyCVal) {
let d2Length = uniforms.Dy_shape[${channelDim}];
for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) {
let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')};
let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')};
let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')};
let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')};
var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')};
let tmpval = vec4<${dataType}>(dot(xValue, wValue0),
dot(xValue, wValue1),
dot(xValue, wValue2),
dot(xValue, wValue3));
dotProd[0] = dotProd[0] + tmpval;
}
} else if (bDyCVal2) {
let d2Length = uniforms.Dy_shape[3];
for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) {
let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')};
let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')};
let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')};
let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')};
var xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')};
let tmpval = vec4<${dataType}>(dot(xValue, wValue0),
dot(xValue, wValue1),
dot(xValue, wValue2),
dot(xValue, wValue3));
dotProd[1] = dotProd[1] + tmpval;
}
}
}
}
for (var i: u32 = 0; i < ${workPerThread}; i = i + 1) {
let value = dotProd[i] + ${hasBias ? 'bias[c+i]' : `vec4<${dataType}>(0.0)`};
${output.set('batch', 'r', 'c + i', 'd1', 'value')};
}
}`;
const codeSnippet = `
let outputIndices = ${output.offsetToIndices('global_idx')};
let batch = ${output.indicesGet('outputIndices', 0)};
let d1 = ${output.indicesGet('outputIndices', channelDim)};
let r = ${output.indicesGet('outputIndices', rowDim)};
let c = ${output.indicesGet('outputIndices', colDim)};
let dyCorner = vec2<i32>(i32(r), i32(c)) - uniforms.pads;
let dyRCorner = dyCorner.x;
let dyCCorner = dyCorner.y;
let groupId = d1 / uniforms.output_channels_per_group;
let wOutChannel = d1 - groupId * uniforms.output_channels_per_group;
// Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
// ? = to be determined. : = across all values in that axis.
var dotProd = ${dataType}(0.0);
for (var wR: u32 = 0; wR < uniforms.effective_filter_dims.x; wR = wR + 1) {
if (wR % uniforms.dilations.x != 0) {
continue;
}
let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(uniforms.strides[0]);
let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x;
if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[${rowDim}]) || fract(dyR) > 0.0 ||
wRPerm < 0) {
continue;
}
let idyR: u32 = u32(dyR);
for (var wC: u32 = 0; wC < uniforms.effective_filter_dims.y; wC = wC + 1) {
if (wC % uniforms.dilations.y != 0) {
continue;
}
let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y);
let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y;
if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[${colDim}]) ||
fract(dyC) > 0.0 || wCPerm < 0) {
continue;
}
let idyC: u32 = u32(dyC);
var inputChannel = groupId * uniforms.input_channels_per_group;
for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + 1) {
let xValue = ${
isChannelsLast
? dy.get('batch', 'idyR', 'idyC', 'inputChannel')
: dy.get('batch', 'inputChannel', 'idyR', 'idyC')
};
let wValue = ${w.get('inputChannel', 'wOutChannel', 'u32(wRPerm)', 'u32(wCPerm)')};
dotProd = dotProd + xValue * wValue;
inputChannel = inputChannel + 1;
}
}
}
let value = dotProd + ${hasBias ? 'bias[d1]' : `${dataType}(0.0)`};
${output.setByOffset('global_idx', 'value')};
`;

return `
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
${declareFunctions}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')};
${isVec4 ? codeSnippet4 : codeSnippet}}`;
};

export const createConvTranspose2DProgramInfo = (
inputs: readonly TensorView[],
attributes: ConvTransposeAttributes,
squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
): ProgramInfo => {
const hasBias = inputs.length > 2;
// const isChannelsLast = attributes.format === 'NHWC';
const outputShape = attributes.outputShape;
const outputSize = ShapeUtil.size(outputShape);

// const inChannels = inputs[0].dims[isChannelsLast ? 3 : 1];
// TODO Enable isVec4 for performance
// Disabled due to weight matrix layout issue
// const isVec4 = attributes.group === 1 && isChannelsLast && inChannels % 4 === 0 && outChannels % 4 === 0;
const isChannelsLast = attributes.format === 'NHWC';
const group = attributes.group;
const wShape = inputs[1].dims;
const inputChannelsPerGroup = wShape[2] / group;
const outputChannelsPerGroup = wShape[3];
const components = isChannelsLast ? getMaxComponents(outputChannelsPerGroup) : 1;
const outputSize = ShapeUtil.size(outputShape) / components;
const dispatch = [Math.ceil(outputSize / 64), 1, 1];
LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`);

const isChannelsLast = attributes.format === 'NHWC';
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
const strides = [attributes.strides[0], attributes.strides[1]];
const filterDims = [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]];
Expand All @@ -268,15 +66,9 @@ export const createConvTranspose2DProgramInfo = (
];
const pads = [
effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2),
effectiveFilterDims[1] - 1 - Math.floor(attributes.pads[1] + attributes.pads[3]) / 2,
effectiveFilterDims[1] - 1 - Math.floor((attributes.pads[1] + attributes.pads[3]) / 2),
];

const isVec4 = false;
const group = attributes.group;
const wShape = inputs[1].dims;
const inputChannelsPerGroup = wShape[0] / group;
const outputChannelsPerGroup = wShape[1];

const programUniforms: ProgramUniform[] = [
{ type: DataType.uint32, data: outputSize },
{ type: DataType.uint32, data: strides },
Expand All @@ -294,7 +86,6 @@ export const createConvTranspose2DProgramInfo = (
}
programUniforms.push(...createTensorShapeVariables(outputShape));

const is1DimensionDispatch = dispatch[1] === 1 && dispatch[2] === 1;
const getShaderSource = (shaderHelper: ShaderHelper) => {
const uniforms: UniformsArrayType = [
{ name: 'output_size', type: 'u32' },
Expand All @@ -307,21 +98,83 @@ export const createConvTranspose2DProgramInfo = (
{ name: 'output_channels_per_group', type: 'u32' },
];
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
return `${createConvTranspose2DOpProgramShaderSource(
shaderHelper,
inputs,
outputShape,
hasBias,
is1DimensionDispatch,
isVec4,
dataType,
uniforms,
isChannelsLast,
)}`;
const rowDim = isChannelsLast ? 1 : 2;
const colDim = isChannelsLast ? 2 : 3;
const channelDim = isChannelsLast ? 3 : 1;

const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components);
const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims.length);
const inputVariables = [dy, w];
if (hasBias) {
inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]].length, components));
}
const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);

const codeSnippet = `
let outputIndices = ${output.offsetToIndices(`global_idx * ${components}`)};
let batch = ${output.indicesGet('outputIndices', 0)};
let d1 = ${output.indicesGet('outputIndices', channelDim)};
let r = ${output.indicesGet('outputIndices', rowDim)};
let c = ${output.indicesGet('outputIndices', colDim)};
let dyCorner = vec2<i32>(i32(r), i32(c)) - uniforms.pads;
let dyRCorner = dyCorner.x;
let dyCCorner = dyCorner.y;
let groupId = d1 / uniforms.output_channels_per_group;
let wOutChannel = d1 - groupId * uniforms.output_channels_per_group;
// Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
// ? = to be determined. : = across all values in that axis.
var dotProd = ${output.type.value}(0.0);
for (var wR: u32 = 0; wR < uniforms.effective_filter_dims.x; wR = wR + 1) {
if (wR % uniforms.dilations.x != 0) {
continue;
}
let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(uniforms.strides[0]);
let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x;
if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[${rowDim}]) || fract(dyR) > 0.0 ||
wRPerm < 0) {
continue;
}
let idyR: u32 = u32(dyR);
for (var wC: u32 = 0; wC < uniforms.effective_filter_dims.y; wC = wC + 1) {
if (wC % uniforms.dilations.y != 0) {
continue;
}
let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y);
let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y;
if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[${colDim}]) ||
fract(dyC) > 0.0 || wCPerm < 0) {
continue;
}
let idyC: u32 = u32(dyC);
var inputChannel = groupId * uniforms.input_channels_per_group;
for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + 1) {
let xValue = ${
isChannelsLast
? dy.get('batch', 'idyR', 'idyC', 'inputChannel')
: dy.get('batch', 'inputChannel', 'idyR', 'idyC')
};
let w_offset = ${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)};
let wValue = ${w.getByOffset(`w_offset / ${components}`)};
dotProd = dotProd + xValue * wValue;
inputChannel = inputChannel + 1;
}
}
}
let value = dotProd${hasBias ? ` + bias[d1 / ${components}]` : ''};
${output.setByOffset('global_idx', 'value')};
`;

return `
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')};
${codeSnippet}}`;
};

return {
name: 'ConvTranspose2D',
shaderCache: { hint: `${attributes.cacheKey};`, inputDependencies },
shaderCache: { hint: `${attributes.cacheKey};${components}`, inputDependencies },
getRunData: () => ({
dispatchGroup: { x: dispatch[0], y: dispatch[1], z: dispatch[2] },
outputs: [
Expand Down
Loading

0 comments on commit 8143b00

Please sign in to comment.