Skip to content

Commit

Permalink
[js/webgpu] Enable the NCHW ConvMatMul path (microsoft#17717)
Browse files Browse the repository at this point in the history
1) Enable pointwise NCHW conv2d by MatMul.
2) Enable non-pointwise NCHW conv2d by convMatMul.
3) Fix bug when `sameSize` is true.

---------

Co-authored-by: Yulong Wang <[email protected]>
  • Loading branch information
2 people authored and kleiti committed Mar 22, 2024
1 parent fde2d39 commit a07c85f
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 49 deletions.
11 changes: 4 additions & 7 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -163,17 +163,14 @@ export const createConv2DMatMulProgramInfo =
const outWidth = isChannelsLast ? outputShape[2] : outputShape[3];
const outHeight = isChannelsLast ? outputShape[1] : outputShape[2];
const outChannels = isChannelsLast ? outputShape[3] : outputShape[1];
const isVec4 = (((inChannels % 4 === 0 || inChannels % 3 === 0) && isChannelsLast) ||
(outWidth % 4 === 0 && !isChannelsLast)) &&
outChannels % 4 === 0;
// TODO: enable vec4 for NCHW
const isVec4 = isChannelsLast && (inChannels % 4 === 0 || inChannels % 3 === 0) && outChannels % 4 === 0;

// TODO: fine tune size
const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight;
const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels;
const workGroupSize: [number, number, number] =
isVec4 ? [8, 8, 1] : [dispatchX <= 4 ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1];
const elementsPerThread =
isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 2, dispatchX > 4 && dispatchY <= 4 ? 1 : 2, 1];
const workGroupSize: [number, number, number] = [8, 8, 1];
const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1];
const dispatch = [
Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]),
Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]),
Expand Down
20 changes: 13 additions & 7 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ export const makeMatMulPackedVec4Source =
workPerThread[0]} must be 4.`);
}
return `
var<workgroup> mm_Asub : array<array<vec${innerElementSize}<${type}>, ${tileAWidth / innerElementSize}>, ${tileAHight}>;
var<workgroup> mm_Bsub : array<array<vec4<${type}>, ${tileBOuter / workPerThread[0]}>, ${tileInner}>;
var<workgroup> mm_Asub: array<array<vec${innerElementSize}<${type}>, ${tileAWidth / innerElementSize}>, ${tileAHight}>;
var<workgroup> mm_Bsub: array<array<vec4<${type}>, ${tileBOuter / workPerThread[0]}>, ${tileInner}>;
const rowPerThread = ${workPerThread[1]};
const colPerThread = ${workPerThread[0]};
Expand Down Expand Up @@ -339,7 +339,8 @@ fn main(@builtin(local_invocation_id) localId : vec3<u32>,
};

const matMulReadWriteFnSource =
(component: number, hasBias: boolean, applyActivation: string, variables: IndicesHelper[]): string => {
(component: number, hasBias: boolean, applyActivation: string, variables: IndicesHelper[],
isChannelsLast = false): string => {
const batchAVariable = variables[0];
const batchBVariable = variables[1];
const batchVariable = variables[2];
Expand Down Expand Up @@ -407,7 +408,10 @@ const matMulReadWriteFnSource =
if (row < dimAOuter && col < dimBOuter) {
var value = valueIn;
let coords = vec3<i32>(batch, row, colIn);
${hasBias ? 'value = value + bias[colIn];' : ''}
${
hasBias ?
`value = value + ${isChannelsLast ? 'bias[colIn]' : `${typeSnippet(component, dataType)}(bias[row])`};` :
'' }
${applyActivation}
${outputVariable.setByIndices('vec3<u32>(coords)', 'value')}
}
Expand All @@ -418,7 +422,8 @@ const matMulReadWriteFnSource =

export const createMatmulProgramInfo =
(metadata: ProgramMetadata, inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes,
outputShape: readonly number[], reshapedOutputShape?: readonly number[]): ProgramInfo => {
outputShape: readonly number[], reshapedOutputShape?: readonly number[],
isChannelsLast = false /* only used for conv2dByMatMul*/): ProgramInfo => {
const aShape = inputs[0].dims;
const bShape = inputs[1].dims;

Expand Down Expand Up @@ -457,9 +462,10 @@ export const createMatmulProgramInfo =
variables.push(output);
const inputVariables = [A, B];
const hasBias = inputs.length > 2;
const declareFunctions = matMulReadWriteFnSource(components, hasBias, applyActivation, variables);
const declareFunctions = matMulReadWriteFnSource(components, hasBias, applyActivation, variables, isChannelsLast);
if (hasBias) {
inputVariables.push(inputVariable('bias', inputs[2].dataType, [dimBOuter / components], components));
const biasComponents = isChannelsLast ? components : 1;
inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims, biasComponents));
}
const getShaderSource = (shaderHelper: ShaderHelper) => `
const dimAOuter: i32 = ${dimAOuter};
Expand Down
76 changes: 43 additions & 33 deletions js/web/lib/wasm/jsep/webgpu/ops/conv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,14 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut

// check attributes

const hasBias = inputs.length === 3;
// const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */
const isChannelsLast = attributes.format === 'NHWC';
if (!isChannelsLast || attributes.group !== 1) {
if (attributes.group !== 1) {
context.compute(createGroupedConvProgramInfoLoader(inputs, adjustedAttributes));
return;
}

// const batchSize = context.inputs[0].dims[0];
const isChannelsLast = attributes.format === 'NHWC';
const hasBias = inputs.length === 3;
const inputHeight = inputs[0].dims[isChannelsLast ? 1 : 2];
const inputWidth = inputs[0].dims[isChannelsLast ? 2 : 3];
const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1];
Expand All @@ -155,47 +154,59 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
const outHeight = outputShape[isChannelsLast ? 1 : 2];
const outWidth = outputShape[isChannelsLast ? 2 : 3];
const outChannels = outputShape[isChannelsLast ? 3 : 1];
const batch = outputShape[0];

const sameSize =
isChannelsLast && weightHeight === inputHeight && weightWidth === inputWidth && attributes.autoPad === 'VALID';
const sameSize = isChannelsLast && weightHeight === inputHeight && weightWidth === inputWidth &&
attributes.pads[0] === 0 && attributes.pads[1] === 0;
if (sameSize ||
(weightHeight === 1 && weightWidth === 1 && attributes.dilations[0] === 1 && attributes.dilations[1] === 1 &&
attributes.strides[0] === 1 && attributes.strides[1] === 1 && attributes.pads[0] === 0 &&
attributes.pads[1] === 0)) {
// conv2dByMatMul
const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ??
context.compute(
{
...transposeProgramMetadata,
cacheHint: weightTransposeAttribute.cacheKey,
get: () => createTransposeProgramInfo(inputs[1], weightTransposeAttribute.perm)
},
{inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0];
if (attributes.wIsConst && !context.kernelCustomData.wT) {
context.kernelCustomData.wT = transposedWeight;
}

const batch = outputShape[0];
let xReshaped, wReshaped, matmulOutputShape;
const matmulInputs = [];
matmulInputs.push(inputs[0].reshape([batch, inputHeight * inputWidth, inputChannels]));
matmulInputs.push(transposedWeight.reshape([1, inputChannels, outChannels]));
if (isChannelsLast) {
const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ??
context.compute(
{
...transposeProgramMetadata,
cacheHint: weightTransposeAttribute.cacheKey,
get: () => createTransposeProgramInfo(inputs[1], weightTransposeAttribute.perm)
},
{inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0];
if (attributes.wIsConst && !context.kernelCustomData.wT) {
context.kernelCustomData.wT = transposedWeight;
}
if (sameSize) {
const sharedDim = inputHeight * inputWidth * inputChannels;
xReshaped = inputs[0].reshape([1, batch, sharedDim]);
wReshaped = transposedWeight.reshape([1, sharedDim, outChannels]);
matmulOutputShape = [1, batch, outChannels];
} else {
xReshaped = inputs[0].reshape([batch, inputHeight * inputWidth, inputChannels]);
wReshaped = transposedWeight.reshape([1, inputChannels, outChannels]);
matmulOutputShape = [batch, outHeight * outWidth, outChannels];
}
matmulInputs.push(xReshaped);
matmulInputs.push(wReshaped);
} else {
xReshaped = inputs[0].reshape([batch, inputChannels, inputHeight * inputWidth]);
wReshaped = inputs[1].reshape([1, outChannels, inputChannels]);
matmulOutputShape = [batch, outChannels, outHeight * outWidth];
matmulInputs.push(wReshaped);
matmulInputs.push(xReshaped);
}
if (hasBias) {
matmulInputs.push(inputs[2]);
}
const matmulOutputShape = [batch, outHeight * outWidth, outChannels];
context.compute(
createMatmulProgramInfoLoader(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape),
createMatmulProgramInfoLoader(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast),
{inputs: matmulInputs});

return;
}

// TODO: implement conv2dWithIm2Col()

const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels;
const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth;
const dimInner = weightHeight * weightWidth * inputChannels;

const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true;

// STEP.1: transpose weight
Expand All @@ -214,14 +225,13 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
// STEP.2: prepare reshaped inputs
const convInputs = [inputs[0], transposedWeight];
if (hasBias) {
if (!isChannelsLast && inputs[2].dims.length === 1) {
convInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1]));
} else {
convInputs.push(inputs[2]);
}
convInputs.push(inputs[2]);
}

// STEP.3: compute matmul
const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels;
const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth;
const dimInner = weightHeight * weightWidth * inputChannels;
context.compute(
createConv2DMatMulProgramInfoLoader(
convInputs, adjustedAttributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias,
Expand Down
5 changes: 3 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/matmul.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ const createMatmulProgramMetadata = (hasBias: boolean, cacheHint: string) => ({

export const createMatmulProgramInfoLoader =
(inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, outputShape: readonly number[],
reshapedOutputShape?: readonly number[]): ProgramInfoLoader => {
reshapedOutputShape?: readonly number[], isChannelsLast = false): ProgramInfoLoader => {
const metadata = createMatmulProgramMetadata(inputs.length > 2, activationAttributes.activationCacheKey);
return {
...metadata,
get: () => createMatmulProgramInfo(metadata, inputs, activationAttributes, outputShape, reshapedOutputShape)
get: () => createMatmulProgramInfo(
metadata, inputs, activationAttributes, outputShape, reshapedOutputShape, isChannelsLast)
};
};

Expand Down
36 changes: 36 additions & 0 deletions js/web/test/data/ops/conv.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,42 @@
}
]
},
{
"name": "conv with bias addition C",
"operator": "Conv",
"inputShapeDefinitions": "rankOnly",
"opset": { "domain": "", "version": 17 },
"attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }],
"cases": [
{
"name": "T[0]",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
"dims": [3, 1, 2, 2],
"type": "float32"
},
{
"data": [1, 1, 1, 1, 2, 3, 4, 5],
"dims": [2, 1, 2, 2],
"type": "float32"
},
{
"data": [5, 6],
"dims": [2],
"type": "float32"
}
],
"outputs": [
{
"data": [15, 46, 31, 102, 47, 158],
"dims": [3, 2, 1, 1],
"type": "float32"
}
]
}
]
},
{
"name": "conv - group - A",
"operator": "Conv",
Expand Down

0 comments on commit a07c85f

Please sign in to comment.