[js/webgpu] Provide a vectorized algorithm for GroupedConv (microsoft#18884)

### Description
This PR provides a vectorized algorithm for NHWC GroupedConv to improve
performance.

The aggregate GroupedConv time in mobilenetv2-12 drops from ~4 ms to ~1 ms on an
Intel Alder Lake machine, roughly a 20% improvement for the whole model.
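
The speedup comes from vectorizing along two axes: with NHWC the channel dimension is innermost, so channel data can be loaded as WGSL `vec4` values, and each shader invocation also produces several adjacent output columns so that overlapping input reads are shared. A minimal sketch of the sizing logic, assuming `getMaxComponents` (in `ops/common.ts`) picks the widest vector width that evenly divides a dimension; the layer shape below is illustrative:

```ts
// Mirrors the assumed behavior of getMaxComponents: widest of 4/2/1 dividing `size`.
const maxComponents = (size: number): 1 | 2 | 4 => (size % 4 === 0 ? 4 : size % 2 === 0 ? 2 : 1);

// Illustrative NHWC depthwise-conv output shape, [N, H, W, C] = [1, 112, 112, 32]:
const components = maxComponents(32);     // 4 -> channels handled as vec4
const outputNumber = maxComponents(112);  // 4 -> 4 adjacent output columns per invocation
// Invocation count shrinks by components * outputNumber = 16x vs. a scalar kernel.
```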
qjia7 authored Jan 11, 2024
1 parent 622afe1 commit aeeaf36
Showing 3 changed files with 271 additions and 6 deletions.
99 changes: 97 additions & 2 deletions web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts
@@ -3,9 +3,9 @@

 import {TensorView} from '../../tensor-view';
 import {ShapeUtil} from '../../util';
-import {ProgramInfo} from '../types';
+import {ProgramInfo, ProgramUniform} from '../types';
 
-import {inputVariable, outputVariable, ShaderHelper} from './common';
+import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common';
 import {calculateOutputShape, ConvAttributes} from './conv';
 import {getActivationSnippet} from './fuse-utils';

@@ -95,3 +95,98 @@ export const createGroupedConvProgramInfo =
        getShaderSource,
      };
    };

export const createGroupedConvVectorizeProgramInfo =
    (inputs: readonly TensorView[], attributes: ConvAttributes, outputShape: readonly number[]): ProgramInfo => {
      const hasBias = inputs.length > 2;
      const components = getMaxComponents(outputShape[3]);
      const outputNumber = getMaxComponents(outputShape[2]);
      const outputSize = ShapeUtil.size(outputShape) / components / outputNumber;
      const xShape = [inputs[0].dims[0], inputs[0].dims[1], inputs[0].dims[2], inputs[0].dims[3] / components];
      const wShape = [inputs[1].dims[0], inputs[1].dims[1], inputs[1].dims[2], inputs[1].dims[3] / components];
      const outputShapeInShader = [outputShape[0], outputShape[1], outputShape[2], outputShape[3] / components];

      const programUniforms: ProgramUniform[] = [
        {type: 'uint32', data: outputSize}, {type: 'int32', data: attributes.strides},
        {type: 'int32', data: attributes.pads}, ...createTensorShapeVariables(xShape),
        ...createTensorShapeVariables(wShape), ...createTensorShapeVariables(outputShapeInShader)
      ];
      // Width of the input window needed to produce outputNumber output columns per invocation.
      const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1];
      const getShaderSource = (shaderHelper: ShaderHelper) => {
        const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components);
        const {activationFunction, applyActivation} = getActivationSnippet(attributes, output.type.value);
        const x = inputVariable('x', inputs[0].dataType, xShape.length, components);
        const w = inputVariable('w', inputs[1].dataType, wShape.length, components);
        const inputVars = [x, w];
        if (hasBias) {
          inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims, components));
        }
        const processBias = hasBias ? 'value += b[output_channel];' : '';

        return `
  ${
            shaderHelper.registerUniform('output_size', 'u32')
                .registerUniform('strides', 'i32', 2)
                .registerUniform('pads', 'i32', 2)
                .declareVariables(...inputVars, output)}
  ${activationFunction}
  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
    let width0 = uniforms.output_shape[3];
    let output_channel = global_idx % width0;
    var index1 = global_idx / width0;
    let width1 = uniforms.output_shape[2] / ${outputNumber}u;
    let col = (index1 % width1) * ${outputNumber}u;
    index1 = index1 / width1;
    let row = index1 % uniforms.output_shape[1];
    let batch = index1 / uniforms.output_shape[1];

    let x_corner = vec2<i32>(i32(row), i32(col)) * uniforms.strides - uniforms.pads;

    var x_vals: array<${x.type.value}, ${xNumber}>;
    var values: array<${output.type.value}, ${outputNumber}>;
    let input_channel = output_channel;
    // Using compile-time constants instead of uniforms for w's height/width gives better performance.
    for (var w_height: u32 = 0u; w_height < ${wShape[0]}; w_height++) {
      let x_height = x_corner.x + i32(w_height);
      if (x_height >= 0 && u32(x_height) < uniforms.x_shape[1]) {
        for (var i = 0; i < ${xNumber}; i++) {
          let x_width = x_corner.y + i;
          if (x_width >= 0 && u32(x_width) < uniforms.x_shape[2]) {
            x_vals[i] = ${x.get('batch', 'u32(x_height)', 'u32(x_width)', 'input_channel')};
          } else {
            x_vals[i] = ${x.type.value}(0);
          }
        }
        for (var w_width: u32 = 0u; w_width < ${wShape[1]}; w_width++) {
          let w_val = ${w.get('w_height', 'w_width', '0', 'output_channel')};
          for (var i = 0u; i < ${outputNumber}u; i++) {
            values[i] = fma(x_vals[i * ${attributes.strides[1]}u + w_width], w_val, values[i]);
          }
        }
      }
    }

    for (var i = 0u; i < ${outputNumber}u; i++) {
      var value = values[i];
      ${processBias}
      ${applyActivation}
      ${output.set('batch', 'row', 'col + i', 'output_channel', 'value')};
    }
  }`;
      };

      return {
        name: 'GroupedConv-Vectorize',
        shaderCache: {
          hint: `${attributes.activationCacheKey};${components};${outputNumber};${xNumber};${wShape[0]};${wShape[1]}`,
          inputDependencies: hasBias ? ['rank', 'rank', 'type'] : ['rank', 'rank']
        },
        getRunData: () => ({
          outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
          dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
          programUniforms
        }),
        getShaderSource,
      };
    };
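
To make the tiling concrete, here is the builder's arithmetic worked through for a hypothetical 3×3 depthwise layer (the shape is illustrative, not taken from the PR):

```ts
const outputShape = [1, 56, 56, 24];  // NHWC, hypothetical
const strides = [1, 1];
const kernelWidth = 3;
const components = 4;    // getMaxComponents(24)
const outputNumber = 4;  // getMaxComponents(56)
const outputSize = (1 * 56 * 56 * 24) / components / outputNumber;  // 4704 invocations
const xNumber = (outputNumber - 1) * strides[1] + kernelWidth;      // 6
// Each invocation caches xNumber = 6 vec4 input columns in x_vals and
// writes outputNumber = 4 vec4 results, so neighboring columns share loads.
```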
26 changes: 23 additions & 3 deletions web/lib/wasm/jsep/webgpu/ops/conv.ts
@@ -8,7 +8,7 @@ import {ComputeContext} from '../types';

 import {createConv2DMatMulProgramInfo} from './3rd-party/conv2d_mm_webgpu';
 import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu';
-import {createGroupedConvProgramInfo} from './conv-grouped';
+import {createGroupedConvProgramInfo, createGroupedConvVectorizeProgramInfo} from './conv-grouped';
 import {InternalActivationAttributes, parseInternalActivationAttributes} from './fuse-utils';
 import {createNaiveMatmulProgramInfo} from './matmul';
 import {createTransposeProgramInfo} from './transpose';
@@ -136,12 +136,32 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
   // check attributes
 
   // const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */
+  const isChannelsLast = attributes.format === 'NHWC';
   if (attributes.group !== 1) {
-    context.compute(createGroupedConvProgramInfo(inputs, adjustedAttributes));
+    if (isChannelsLast && inputs[1].dims[0] === attributes.group && inputs[1].dims[1] === 1 &&
+        attributes.dilations[0] === 1 && attributes.dilations[1] === 1) {
+      const outputShape = calculateOutputShape(
+          inputs[0].dims, inputs[1].dims, attributes.dilations, adjustedAttributes.pads, attributes.strides,
+          isChannelsLast);
+      const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ??
+          context.compute(
+              createTransposeProgramInfo(inputs[1], weightTransposeAttribute),
+              {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0];
+      if (attributes.wIsConst && !context.kernelCustomData.wT) {
+        context.kernelCustomData.wT = transposedWeight;
+      }
+      const convInputs = [inputs[0], transposedWeight];
+      if (inputs.length === 3) {
+        convInputs.push(inputs[2]);
+      }
+      context.compute(
+          createGroupedConvVectorizeProgramInfo(convInputs, adjustedAttributes, outputShape), {inputs: convInputs});
+    } else {
+      context.compute(createGroupedConvProgramInfo(inputs, adjustedAttributes));
+    }
     return;
   }
 
-  const isChannelsLast = attributes.format === 'NHWC';
   const hasBias = inputs.length === 3;
   const inputHeight = inputs[0].dims[isChannelsLast ? 1 : 2];
   const inputWidth = inputs[0].dims[isChannelsLast ? 2 : 3];
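
The guard above means the vectorized path only fires for NHWC depthwise convolutions without dilation; everything else falls back to the original grouped kernel. A hypothetical standalone restatement of the condition (`canVectorizeGroupedConv` is illustrative, not a function in the PR), with types imported as in conv.ts:

```ts
import {TensorView} from '../../tensor-view';
import {ConvAttributes} from './conv';

// Hypothetical restatement of the dispatch guard (not part of the PR).
const canVectorizeGroupedConv = (inputs: readonly TensorView[], attributes: ConvAttributes): boolean =>
    attributes.format === 'NHWC' &&            // channels-last layout
    inputs[1].dims[0] === attributes.group &&  // one output channel per group
    inputs[1].dims[1] === 1 &&                 // one input channel per group (depthwise)
    attributes.dilations[0] === 1 && attributes.dilations[1] === 1;  // no dilation
```

The weight is also transposed up front (and cached in `kernelCustomData.wT` when it is a constant), so the shader can index it as [kH, kW, 1, M], matching the `wShape` and `w.get(...)` usage in conv-grouped.ts above.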
152 changes: 151 additions & 1 deletion web/test/data/ops/conv.jsonc
@@ -298,7 +298,157 @@
}
]
},

{
"name": "conv - vectorize group - A",
"operator": "Conv",
"inputShapeDefinitions": "rankOnly",
"opset": { "domain": "", "version": 17 },
"attributes": [
{ "name": "kernel_shape", "data": [1, 1], "type": "ints" },
{ "name": "group", "data": 2, "type": "int" }
],
"cases": [
{
"name": "T[0]",
"inputs": [
{
"data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0],
"dims": [1, 2, 3, 3],
"type": "float32"
},
{
"data": [1.0, 2.0],
"dims": [2, 1, 1, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0, 32.0, 34.0],
"dims": [1, 2, 3, 3],
"type": "float32"
}
]
}
]
},
{
"name": "conv - vectorize group - B",
"operator": "Conv",
"inputShapeDefinitions": "rankOnly",
"opset": { "domain": "", "version": 17 },
"attributes": [
{ "name": "kernel_shape", "data": [2, 2], "type": "ints" },
{ "name": "group", "data": 3, "type": "int" }
],
"cases": [
{
"name": "T[0]",
"inputs": [
{
"data": [
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
19.0, 20.0, 21.0, 22.0, 23.0, 0, 0, 0
],
"dims": [1, 3, 3, 3],
"type": "float32"
},
{
"data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
"dims": [3, 1, 2, 2],
"type": "float32"
},
{
"data": [0.1, 0.2, 0.3],
"dims": [3],
"type": "float32"
}
],
"outputs": [
{
"data": [27.1, 37.1, 57.1, 67.1, 293.2, 319.2, 371.2, 397.2, 847.3, 889.3, 409.3, 428.3],
"dims": [1, 3, 2, 2],
"type": "float32"
}
]
}
]
},
{
"name": "conv - vectorize group - C",
"operator": "Conv",
"inputShapeDefinitions": "rankOnly",
"opset": { "domain": "", "version": 17 },
"attributes": [
{ "name": "kernel_shape", "data": [2, 2], "type": "ints" },
{ "name": "group", "data": 3, "type": "int" }
],
"cases": [
{
"name": "T[0]",
"inputs": [
{
"data": [
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0
],
"dims": [1, 3, 3, 4],
"type": "float32"
},
{
"data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
"dims": [3, 1, 2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [34, 44, 54, 74, 84, 94, 386, 412, 438, 490, 516, 542, 1122, 1164, 1206, 1290, 1332, 1374],
"dims": [1, 3, 2, 3],
"type": "float32"
}
]
}
]
},
{
"name": "conv - vectorize group - D",
"operator": "Conv",
"inputShapeDefinitions": "rankOnly",
"opset": { "domain": "", "version": 17 },
"attributes": [
{ "name": "kernel_shape", "data": [2, 2], "type": "ints" },
{ "name": "group", "data": 3, "type": "int" },
{ "name": "strides", "data": [2, 2], "type": "ints" }
],
"cases": [
{
"name": "T[0] strides = [2, 2]",
"inputs": [
{
"data": [
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0
],
"dims": [1, 3, 3, 4],
"type": "float32"
},
{
"data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
"dims": [3, 1, 2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [34, 54, 386, 438, 1122, 1206],
"dims": [1, 3, 1, 2],
"type": "float32"
}
]
}
]
},
{
"name": "conv - pointwise",
"operator": "Conv",
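
The expected outputs in the four "vectorize group" cases above can be reproduced with a naive NCHW grouped convolution. A minimal reference sketch (function name and signature are illustrative, not part of the test harness); the assertion checks case A, and the routine accepts the shapes, strides, and bias needed for cases B–D:

```ts
import {strict as assert} from 'node:assert';

// Naive NCHW grouped convolution (no padding, no dilation), as a reference
// for the expected outputs in the test cases above.
const groupedConvRef = (
    x: number[], xDims: number[], w: number[], wDims: number[], group: number,
    strides: [number, number], bias?: number[]): {data: number[]; dims: number[]} => {
  const [n, c, h, wIn] = xDims;
  const [m, cPerGroup, kh, kw] = wDims;
  const oh = Math.floor((h - kh) / strides[0]) + 1;
  const ow = Math.floor((wIn - kw) / strides[1]) + 1;
  const out = new Array<number>(n * m * oh * ow).fill(0);
  const mPerGroup = m / group;
  for (let b = 0; b < n; b++) {
    for (let oc = 0; oc < m; oc++) {
      const g = Math.floor(oc / mPerGroup);  // group this output channel belongs to
      for (let oy = 0; oy < oh; oy++) {
        for (let ox = 0; ox < ow; ox++) {
          let acc = bias ? bias[oc] : 0;
          for (let ic = 0; ic < cPerGroup; ic++) {
            for (let ky = 0; ky < kh; ky++) {
              for (let kx = 0; kx < kw; kx++) {
                const iy = oy * strides[0] + ky;
                const ix = ox * strides[1] + kx;
                acc += x[((b * c + g * cPerGroup + ic) * h + iy) * wIn + ix] *
                    w[((oc * cPerGroup + ic) * kh + ky) * kw + kx];
              }
            }
          }
          out[((b * m + oc) * oh + oy) * ow + ox] = acc;
        }
      }
    }
  }
  return {data: out, dims: [n, m, oh, ow]};
};

// "conv - vectorize group - A": 1x1 kernel, group 2 -> the second channel is doubled.
const a = groupedConvRef(
    Array.from({length: 18}, (_, i) => i), [1, 2, 3, 3], [1, 2], [2, 1, 1, 1], 2, [1, 1]);
assert.deepEqual(a.data.slice(9), [18, 20, 22, 24, 26, 28, 30, 32, 34]);
```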
