[JS/Web] Add ConvTranspose implementation using MatMul (microsoft#17573)
### Description
Add a ConvTranspose implementation that lowers the op to MatMul to improve performance. The MatMul path is taken when `group == 1`; other cases keep the existing implementation.
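For context, the new kernel expresses the transposed convolution as a single matrix multiply. A minimal sketch of the dimension mapping (mirroring `dimAOuter`/`dimBOuter`/`dimInner` in the new code; `gemmDims` is an illustrative helper, not part of this change):

```ts
// ConvTranspose as one GEMM (sketch, names illustrative).
// NHWC: A gathers input patches [outH * outW, kH * kW * inC];
//       B is the transposed weight [kH * kW * inC, outC].
// NCHW swaps the two outer dimensions.
const gemmDims =
    (isChannelsLast: boolean, outH: number, outW: number, outC: number, kH: number, kW: number,
     inC: number) => ({
      dimAOuter: isChannelsLast ? outH * outW : outC,
      dimBOuter: isChannelsLast ? outC : outH * outW,
      dimInner: kH * kW * inC,
    });
```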


satyajandhyala authored and kleiti committed Mar 22, 2024
1 parent 6b7e2fc commit a367f56
Showing 6 changed files with 471 additions and 19 deletions.
243 changes: 243 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts
@@ -0,0 +1,243 @@
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

// sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts
//
// modified to fit the needs of the project

import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ShapeUtil} from '../../../util';
import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types';
import {ConvTransposeAttributes} from '../conv-transpose';

import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util';
import {utilFunctions} from './conv_util';
import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu';

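// Generates the layout-dependent WGSL helpers (mm_readA, mm_readB, mm_write) that the shared
// matmul kernel sources call to gather input/weight elements and scatter results.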
const conv2dTransposeCommonSnippet =
(isChannelsLast: boolean, addBias = false, activation?: Activation, hasPreluActivationWeights = false,
innerElementSize = 4): string => {
const getWSnippet = (innerElementSize: number) => {
switch (innerElementSize) {
case 1:
return 'return W[getIndexFromCoords4D(coord, wShape)];';
case 4:
return `
let coord1 = vec4<i32>(coordX, coordY, col + 1, rowInner);
let coord2 = vec4<i32>(coordX, coordY, col + 2, rowInner);
let coord3 = vec4<i32>(coordX, coordY, col + 3, rowInner);
let v0 = W[getIndexFromCoords4D(coord, wShape)];
let v1 = W[getIndexFromCoords4D(coord1, wShape)];
let v2 = W[getIndexFromCoords4D(coord2, wShape)];
let v3 = W[getIndexFromCoords4D(coord3, wShape)];
return vec4<f32>(v0, v1, v2, v3);
`;
default:
throw new Error(`innerElementSize ${innerElementSize} is not supported.`);
}
};
const coordASnippet = isChannelsLast ? `
let coord = vec4<i32>(batch, iXR, iXC, xCh);
` :
`
let coord = vec4<i32>(batch, xCh, iXR, iXC);
`;

const coordResSnippet = isChannelsLast ? `
let coords = vec4<i32>(
batch,
row / outWidth,
row % outWidth,
col);
` :
`
let coords = vec4<i32>(
batch,
row,
col / outWidth,
col % outWidth);
`;

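// In NHWC the output spatial index is the matmul row and the reduction index the column;
// NCHW swaps these roles, so alias them here.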
const xHeight = isChannelsLast ? 'outBackprop[1]' : 'outBackprop[2]';
const xWidth = isChannelsLast ? 'outBackprop[2]' : 'outBackprop[3]';
const row = isChannelsLast ? 'row' : 'col';
const col = isChannelsLast ? 'col' : 'row';

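// Maps a (row, col) matmul coordinate back to an input element; positions falling into the
// stride/padding gaps of the transposed convolution read as zero.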
const readASnippet = `
let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'};
let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'};
let outRow = ${row} / outWidth;
let outCol = ${row} % outWidth;
let WRow = ${col} / (filterDims[1] * inChannels);
let WCol = ${col} / inChannels % filterDims[1];
let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]);
let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]);
if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) {
return ${typeSnippet(innerElementSize)}(0.0);
}
if (xC < 0.0 || xC >= f32(${xWidth}) || fract(xC) > 0.0) {
return ${typeSnippet(innerElementSize)}(0.0);
}
let iXR = i32(xR);
let iXC = i32(xC);
let xCh = ${col} % inChannels;
${coordASnippet}
return x[getIndexFromCoords4D(coord, xShape)/${innerElementSize}];`;

const sampleA = isChannelsLast ? `
let col = colIn * ${innerElementSize};
if (row < dimAOuter && col < dimInner) {
${readASnippet}
}
return ${typeSnippet(innerElementSize)}(0.0);` :
`
let col = colIn * ${innerElementSize};
if (row < dimInner && col < dimBOuter) {
${readASnippet}
}
return ${typeSnippet(innerElementSize)}(0.0);`;

const sampleW = `
let col = colIn * ${innerElementSize};
let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'};
let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels);
let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1];
if (${
isChannelsLast ? 'row < dimInner && col < dimBOuter' :
'row < dimInner && col < dimAOuter'} && coordX >= 0 && coordY >= 0) {
let rowInner = row % inChannels;
let coord = vec4<i32>(coordX, coordY, col, rowInner);
${getWSnippet(innerElementSize)}
}
return ${typeSnippet(innerElementSize)}(0.0);
`;


const userCode = `
${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)}
fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${typeSnippet(innerElementSize)} {
${isChannelsLast ? sampleA : sampleW}
}
fn mm_readB(batch: i32, row : i32, colIn : i32) -> ${typeSnippet(innerElementSize)} {
${isChannelsLast ? sampleW : sampleA}
}
fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${typeSnippet(innerElementSize)}) {
let col = colIn * ${innerElementSize};
if (row < dimAOuter && col < dimBOuter) {
var value = valueInput;
let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'};
${coordResSnippet}
${biasActivationSnippet(addBias, activation)}
result[getIndexFromCoords4D(coords, outShape)/${innerElementSize}] = value;
}
}`;
return userCode;
};

export const createConv2DTransposeMatMulProgramInfo =
(inputs: readonly TensorView[], metadata: ProgramMetadata, attributes: ConvTransposeAttributes,
outputShape: readonly number[], dimAOuter: number, dimBOuter: number, dimInner: number, hasBias: boolean,
sequentialAccessByThreads: boolean): ProgramInfo => {
const isChannelsLast = attributes.format === 'NHWC';
const inChannels = isChannelsLast ? inputs[0].dims[3] : inputs[0].dims[1];
const batchSize = outputShape[0];
const outWidth = isChannelsLast ? outputShape[2] : outputShape[3];
const outHeight = isChannelsLast ? outputShape[1] : outputShape[2];
const outChannels = isChannelsLast ? outputShape[3] : outputShape[1];
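// The vec4 path needs 4-element alignment on the innermost dimensions:
// channels for NHWC, output width for NCHW (plus output channels in both layouts).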
const isVec4 =
isChannelsLast ? inChannels % 4 === 0 && outChannels % 4 === 0 : outWidth % 4 === 0 && outChannels % 4 === 0;

// TODO: fine tune size
const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight;
const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels;
const workGroupSize: [number, number, number] = isVec4 ?
[8, 8, 1] :
[(dispatchX <= 4 || dispatchY <= 4) ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1];
const elementsPerThread =
isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 4, dispatchX > 4 && dispatchY <= 4 ? 1 : 4, 1];
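// Workgroup count: each group covers workGroupSize * elementsPerThread outputs in x and y,
// with the batch dimension mapped to z.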
const dispatch = [
Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]),
Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]),
Math.ceil(batchSize / workGroupSize[2] / elementsPerThread[2])
];

LOG_DEBUG('verbose', () => `[conv_backprop_mm_webgpu] dispatch = ${dispatch}`);

const innerElementSize = isVec4 ? 4 : 1;
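// K-dimension tile size handed to the shared matmul kernel generators below.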
const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]);


const declareInputs = [
`@group(0) @binding(0) var<storage, read> x: array<${isVec4 ? 'vec4<f32>' : 'f32'}>;`,
'@group(0) @binding(1) var<storage, read> W: array<f32>;'
];
let declareFunctions = '';
if (hasBias) {
declareInputs.push(`@group(0) @binding(2) var<storage, read> bias: array<${isVec4 ? 'vec4<f32>' : 'f32'}>;`);
declareFunctions += `
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? 'vec4<f32>' : 'f32'} {
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
}`;
}
return {
...metadata,
outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}],
dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}),
getShaderSource: () => `
${utilFunctions}
${declareInputs.join('\n')}
@group(0) @binding(${declareInputs.length}) var<storage, read_write> result: array<${
isVec4 ? 'vec4<f32>' : 'f32'}>;
const outBackprop : vec4<i32> = vec4<i32>(${inputs[0].dims.join(',')});
const xShape : vec4<i32> = vec4<i32>(${inputs[0].dims.join(',')});
const wShape : vec4<i32> = vec4<i32>(${inputs[1].dims.join(',')});
const outShape : vec4<i32> = vec4<i32>(${outputShape.join(',')});
const outShapeStrides : vec3<i32> = vec3<i32>(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')});
const filterDims : vec2<i32> = vec2<i32>(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${
attributes.kernelShape[isChannelsLast ? 2 : 3]});
const effectiveFilterDims : vec2<i32> = filterDims + vec2<i32>(
${
attributes.dilations[0] <= 1 ?
0 :
(attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)},
${
attributes.dilations[1] <= 1 ?
0 :
(attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)});
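// Transposed-conv padding: effective (dilated) filter extent minus one, minus the averaged forward pads.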
const pads : vec2<i32> = vec2<i32>(i32(effectiveFilterDims[0]) - 1 - (${
attributes.pads[0] + attributes.pads[2]})/2,
i32(effectiveFilterDims[1]) - 1 - (${
attributes.pads[1] + attributes.pads[3]})/2);
const strides : vec2<i32> = vec2<i32>(${attributes.strides[0]}, ${attributes.strides[1]});
const dilation : vec2<i32> = vec2<i32>(${attributes.dilations[0]}, ${attributes.dilations[1]});
const dimAOuter : i32 = ${dimAOuter};
const dimBOuter : i32 = ${dimBOuter};
const dimInner : i32 = ${dimInner};
${declareFunctions}
${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, undefined, false, innerElementSize)}
${
isVec4 ?
makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner) :
makeMatMulPackedSource(
elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner, false, undefined,
sequentialAccessByThreads)}`
};
};
js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts
@@ -197,14 +197,14 @@ const createConvTranspose2DOpProgramShaderSource
continue;
}
let idyC: u32 = u32(dyC);
- var inputChannel = groupId * ${inputChannelsPerGroup};
for (var d2: u32 = 0; d2 < ${inputChannelsPerGroup}; d2 = d2 + 1) {
+ let inputChannel = groupId * ${inputChannelsPerGroup} + d2;
let xValue = ${
isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') :
dy.get('batch', 'inputChannel', 'idyR', 'idyC')};
let wValue = ${w.get('inputChannel', 'wOutChannel', 'u32(wRPerm)', 'u32(wCPerm)')};
dotProd = dotProd + xValue * wValue;
- inputChannel = inputChannel + 1;
}
}
}
66 changes: 61 additions & 5 deletions js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts
@@ -8,7 +8,9 @@ import {ComputeContext, GpuDataType, ProgramInfoLoader, ProgramMetadata} from '.

import {createConvTranspose2DProgramInfo} from './3rd-party/conv_backprop_webgpu';
import {ConvAttributes} from './conv';
+ import {createConv2DTransposeMatMulProgramInfoLoader} from './conv2dtranspose-mm';
import {parseInternalActivationAttributes} from './fuse-utils';
+ import {createTransposeProgramInfo, TransposeAttributes, transposeProgramMetadata} from './transpose';

const computeTotalPad =
(inDim: number, stride: number, adj: number, kernel: number, dilation: number, outSize: number) =>
@@ -63,7 +65,7 @@ const getAdjustedConvTransposeAttributes
<T extends ConvTransposeAttributes>(attributes: T, inputs: readonly TensorView[]): T => {
const kernelShape = attributes.kernelShape.slice();
// if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims
- if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 0) === 0) {
+ if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 1) === 0) {
kernelShape.length = 0;
for (let i = 2; i < inputs[1].dims.length; ++i) {
kernelShape.push(inputs[1].dims[i]);
@@ -95,9 +97,11 @@ const getAdjustedConvTransposeAttributes

// always return a new object so does not modify the original attributes
const newAttributes: T = Object.assign({}, attributes);
- Object.assign(
-     newAttributes,
-     {kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey: attributes.cacheKey});
+ const cacheKey = attributes.cacheKey + [
+   kernelShape.join('n,'), pads.join(','), strides.join(','), outputPadding.join(','), outputShape.join(','),
+   dilations.join(',')
+ ].join('_');
+ Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey});
return newAttributes;
};

@@ -226,12 +230,64 @@ const createConvTranspose2DProgramInfoLoader
};
};

// for transposing weight tensor from [C, M/group, KH, KW] to [KH, KW, M/group, C]
const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [2, 3, 1, 0]});

const convTranspose2d =
(context: ComputeContext, inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => {
const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs);
const isChannelsLast = attributes.format === 'NHWC';
const hasBias = inputs.length === 3;
if (adjustedAttributes.group !== 1) {
context.compute(createConvTranspose2DProgramInfoLoader(inputs, adjustedAttributes));
return;
}
const outputShape = adjustedAttributes.outputShape;
const outHeight = outputShape[isChannelsLast ? 1 : 2];
const outWidth = outputShape[isChannelsLast ? 2 : 3];
const outChannels = outputShape[isChannelsLast ? 3 : 1];
const weightHeight = inputs[1].dims[2];
const weightWidth = inputs[1].dims[3];
const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1];

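// GEMM view: result[dimAOuter x dimBOuter] = A[dimAOuter x dimInner] * B[dimInner x dimBOuter].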
const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels;
const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth;
const dimInner = weightHeight * weightWidth * inputChannels;

const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true;


- context.compute(createConvTranspose2DProgramInfoLoader(inputs, adjustedAttributes));
// STEP.1: transpose weight
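// For constant weights the transposed tensor is created as a persistent output and cached on
// kernelCustomData, so the transpose runs only once.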
const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ??
context.compute(
{
...transposeProgramMetadata,
cacheHint: weightTransposeAttribute.cacheKey,
get: () => createTransposeProgramInfo(inputs[1], weightTransposeAttribute.perm)
},
{inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0];
if (attributes.wIsConst && !context.kernelCustomData.wT) {
context.kernelCustomData.wT = transposedWeight;
}

// STEP.2: prepare reshaped inputs
const convTransposeInputs = [inputs[0], transposedWeight];
if (hasBias) {
if (!isChannelsLast && inputs[2].dims.length === 1) {
convTransposeInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1]));
} else {
convTransposeInputs.push(inputs[2]);
}
}

// STEP.3: compute matmul
context.compute(
createConv2DTransposeMatMulProgramInfoLoader(
convTransposeInputs, adjustedAttributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias,
sequentialAccessByThreads),
{inputs: convTransposeInputs});
};

const convTranspose1d = (context: ComputeContext, attributes: ConvTransposeAttributes): void => {
// extend the input to 2D by adding H dimension
const isChannelLast = attributes.format === 'NHWC';
29 changes: 29 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/conv2dtranspose-mm.ts
@@ -0,0 +1,29 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {TensorView} from '../../tensor-view';
import {GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types';

import {createConv2DTransposeMatMulProgramInfo} from './3rd-party/conv_backprop_mm_webgpu';
import {ConvTransposeAttributes} from './conv-transpose';


const createConv2DTransposeMatMulProgramMetadata = (hasBias: boolean, cacheHint: string): ProgramMetadata => ({
name: 'Conv2DTransposeMatMul',
inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] :
[GpuDataType.default, GpuDataType.default],
cacheHint
});

export const createConv2DTransposeMatMulProgramInfoLoader =
(inputs: readonly TensorView[], attributes: ConvTransposeAttributes, outputShape: readonly number[],
dimAOuter: number, dimBOuter: number, dimInner: number, hasBias: boolean,
sequentialAccessByThreads: boolean): ProgramInfoLoader => {
const metadata = createConv2DTransposeMatMulProgramMetadata(hasBias, attributes.cacheKey);
return {
...metadata,
get: () => createConv2DTransposeMatMulProgramInfo(
inputs, metadata, attributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias,
sequentialAccessByThreads)
};
};