Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/web] FP16 Conv, ConvTranspose and MatMul #17514

Merged
merged 7 commits into from
Sep 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@

export declare type Activation = 'linear' | 'relu' | 'prelu' | 'elu' | 'relu6' | 'leakyrelu' | 'sigmoid' | 'gelu';

export const typeSnippet = (component: number) => {
export const typeSnippet = (component: number, dataType: string) => {
switch (component) {
case 1:
return 'f32';
return dataType;
case 2:
return 'vec2<f32>';
return `vec2<${dataType}>`;
case 3:
return 'vec3<f32>';
return `vec3<${dataType}>`;
case 4:
return 'vec4<f32>';
return `vec4<${dataType}>`;
default:
throw new Error(`${component}-component is not supported.`);
}
Expand Down
42 changes: 23 additions & 19 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ShapeUtil} from '../../../util';
import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types';
import {tensorTypeToWsglStorageType} from '../common';
import {ConvAttributes} from '../conv';

import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util';
Expand All @@ -32,13 +33,13 @@ import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packe
const conv2dCommonSnippet =
(isChannelsLast: boolean, fitAOuter: boolean, fitBOuter: boolean, fitInner: boolean, addBias = false,
activation?: Activation, hasPreluActivationWeights = false, innerElementSizeX = 4, innerElementSizeW = 4,
innerElementSize = 4): string => {
innerElementSize = 4, dataType = 'f32'): string => {
const getXSnippet = (innerElementSize: number) => {
switch (innerElementSize) {
case 1:
return 'resData = x[xIndex];';
case 3:
return 'resData = vec3<f32>(x[xIndex], x[xIndex + 1], x[xIndex + 2]);';
return `resData = vec3<${dataType}>(x[xIndex], x[xIndex + 1], x[xIndex + 2]);`;
case 4:
return 'resData = x[xIndex / 4];';
default:
Expand Down Expand Up @@ -92,7 +93,7 @@ const conv2dCommonSnippet =
let xRow = outRow * stride[0] + dilation[0] * WRow - pad[0];
let xCol = outCol * stride[1] + dilation[1] * WCol - pad[1];
let xCh = ${col} % inChannels;
var resData = ${typeSnippet(innerElementSizeX)}(0.0);
var resData = ${typeSnippet(innerElementSizeX, dataType)}(0.0);
// The bounds checking is always needed since we use it to pad zero for
// the 'same' padding type.
if (xRow >= 0 && xRow < ${xHeight} && xCol >= 0 && xCol < ${xWidth}) {
Expand All @@ -110,7 +111,7 @@ const conv2dCommonSnippet =
if (row < dimAOuter && col < dimInner) {
${readXSnippet}
}
return ${typeSnippet(innerElementSizeX)}(0.0);`) :
return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) :
(fitInner && fitBOuter ? `
let col = colIn * ${innerElementSizeX};
${readXSnippet}` :
Expand All @@ -119,13 +120,15 @@ const conv2dCommonSnippet =
if (row < dimInner && col < dimBOuter) {
${readXSnippet}
}
return ${typeSnippet(innerElementSizeX)}(0.0);`);
return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`);

const sampleW = `${getWSnippet(innerElementSizeW)}`;

const resType = typeSnippet(innerElementSize);
const aType = isChannelsLast ? typeSnippet(innerElementSizeX) : typeSnippet(innerElementSizeW);
const bType = isChannelsLast ? typeSnippet(innerElementSizeW) : typeSnippet(innerElementSizeX);
const resType = typeSnippet(innerElementSize, dataType);
const aType =
isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType);
const bType =
isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType);
const userCode = `
${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)}
fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} {
Expand Down Expand Up @@ -190,23 +193,24 @@ export const createConv2DMatMulProgramInfo =
const fitInner = dimInner % tileInner === 0;

const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1];
const t = tensorTypeToWsglStorageType(inputs[0].dataType);

const declareInputs = [
`@group(0) @binding(0) var<storage, read> x: array<${isVec4 && innerElementSize === 4 ? 'vec4<f32>' : 'f32'}>;`,
`@group(0) @binding(1) var<storage, read> w: array<${isVec4 ? 'vec4<f32>' : 'f32'}>;`
`@group(0) @binding(0) var<storage, read> x: array<${isVec4 && innerElementSize === 4 ? `vec4<${t}>` : t}>;`,
`@group(0) @binding(1) var<storage, read> w: array<${isVec4 ? `vec4<${t}>` : t}>;`
];
let declareFunctions = `
fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? 'vec4<f32>' : 'f32'}) {
result[flatIndex] = ${isVec4 ? 'vec4<f32>' : 'f32'}(value);
fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) {
result[flatIndex] = ${isVec4 ? `vec4<${t}>` : t}(value);
}
fn setOutputAtCoords(d0 : i32, d1 : i32, d2 : i32, d3 : i32, value : ${isVec4 ? 'vec4<f32>' : 'f32'}) {
fn setOutputAtCoords(d0 : i32, d1 : i32, d2 : i32, d3 : i32, value : ${isVec4 ? `vec4<${t}>` : t}) {
let flatIndex = getOutputIndexFromCoords(vec4<i32>(d0, d1, d2, d3));
setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value);
}`;
if (hasBias) {
declareInputs.push(`@group(0) @binding(2) var<storage, read> bias: array<${isVec4 ? 'vec4<f32>' : 'f32'}>;`);
declareInputs.push(`@group(0) @binding(2) var<storage, read> bias: array<${isVec4 ? `vec4<${t}>` : t}>;`);
declareFunctions += `
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? 'vec4<f32>' : 'f32'} {
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? `vec4<${t}>` : t} {
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
}`;
}
Expand All @@ -222,7 +226,7 @@ export const createConv2DMatMulProgramInfo =
// dilation : vec2<i32>, dimAOuter : i32, dimBOuter : i32, dimInner : i32 };
${declareInputs.join('')}
@group(0) @binding(${declareInputs.length}) var<storage, read_write> result: array<${
isVec4 ? 'vec4<f32>' : 'f32'}>;
isVec4 ? `vec4<${t}>` : t}>;
//@group(0) @binding(${declareInputs.length + 1}) var<uniform> uniforms: Uniforms;

const xShape : vec4<i32> = vec4<i32>(${inputs[0].dims.join(',')});
Expand All @@ -240,12 +244,12 @@ export const createConv2DMatMulProgramInfo =
${
conv2dCommonSnippet(
isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, undefined, false, elementsSize[0],
elementsSize[1], elementsSize[2])}
elementsSize[1], elementsSize[2], t)}
${
isVec4 ?
makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner) :
makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) :
makeMatMulPackedSource(
elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner, false, undefined,
elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner, false, undefined,
sequentialAccessByThreads)}`
};
};
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packe
const conv2dTransposeCommonSnippet =
(isChannelsLast: boolean, addBias = false, activation?: Activation, hasPreluActivationWeights = false,
innerElementSize = 4): string => {
const type = typeSnippet(innerElementSize, 'f32');
const getWSnippet = (innerElementSize: number) => {
switch (innerElementSize) {
case 1:
Expand Down Expand Up @@ -89,10 +90,10 @@ const conv2dTransposeCommonSnippet =
let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]);
let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]);
if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) {
return ${typeSnippet(innerElementSize)}(0.0);
return ${type}(0.0);
}
if (xC < 0.0 || xC >= f32(${xWidth}) || fract(xC) > 0.0) {
return ${typeSnippet(innerElementSize)}(0.0);
return ${type}(0.0);
}
let iXR = i32(xR);
let iXC = i32(xC);
Expand All @@ -105,13 +106,13 @@ const conv2dTransposeCommonSnippet =
if (row < dimAOuter && col < dimInner) {
${readASnippet}
}
return ${typeSnippet(innerElementSize)}(0.0);` :
return ${type}(0.0);` :
`
let col = colIn * ${innerElementSize};
if (row < dimInner && col < dimBOuter) {
${readASnippet}
}
return ${typeSnippet(innerElementSize)}(0.0);`;
return ${type}(0.0);`;

const sampleW = `
let col = colIn * ${innerElementSize};
Expand All @@ -125,21 +126,21 @@ const conv2dTransposeCommonSnippet =
let coord = vec4<i32>(coordX, coordY, col, rowInner);
${getWSnippet(innerElementSize)}
}
return ${typeSnippet(innerElementSize)}(0.0);
return ${type}(0.0);
`;


const userCode = `
${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)}
fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${typeSnippet(innerElementSize)} {
fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${type} {
${isChannelsLast ? sampleA : sampleW}
}

fn mm_readB(batch: i32, row : i32, colIn : i32) -> ${typeSnippet(innerElementSize)} {
fn mm_readB(batch: i32, row : i32, colIn : i32) -> ${type} {
${isChannelsLast ? sampleW : sampleA}
}

fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${typeSnippet(innerElementSize)}) {
fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${type}) {
let col = colIn * ${innerElementSize};
if (row < dimAOuter && col < dimBOuter) {
var value = valueInput;
Expand Down Expand Up @@ -234,10 +235,10 @@ export const createConv2DTransposeMatMulProgramInfo =
${declareFunctions}
${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, undefined, false, innerElementSize)}
${
isVec4 ?
makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner) :
makeMatMulPackedSource(
elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner, false, undefined,
sequentialAccessByThreads)}`
isVec4 ? makeMatMulPackedVec4Source(
elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) :
makeMatMulPackedSource(
elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false,
undefined, sequentialAccessByThreads)}`
};
};
47 changes: 25 additions & 22 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@ import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ShapeUtil} from '../../../util';
import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types';
import {inputVariable, outputVariable, ShaderHelper} from '../common';
import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common';
import {ConvTransposeAttributes} from '../conv-transpose';

const createConvTranspose2DOpProgramShaderSource =
(shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: ConvTransposeAttributes,
outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false): string => {
outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false,
dataType: string): string => {
const isChannelsLast = attributes.format === 'NHWC';
const rowDim = isChannelsLast ? 1 : 2;
const colDim = isChannelsLast ? 2 : 3;
Expand All @@ -39,12 +40,12 @@ const createConvTranspose2DOpProgramShaderSource =
const outputChannelsPerGroup = wShape[1];

let declareFunctions = `
fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? 'vec4<f32>' : 'f32'}) {
result[flatIndex] = ${isVec4 ? 'vec4<f32>' : 'f32'}(value);
fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? `vec4<${dataType}>` : dataType}) {
result[flatIndex] = ${isVec4 ? `vec4<${dataType}>` : dataType}(value);
}`;
if (hasBias) {
declareFunctions += `
fn getBiasByOutputCoords(coords : vec4<u32>) -> ${isVec4 ? 'vec4<f32>' : 'f32'} {
fn getBiasByOutputCoords(coords : vec4<u32>) -> ${isVec4 ? `vec4<${dataType}>` : dataType} {
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
}`;
}
Expand All @@ -66,33 +67,33 @@ const createConvTranspose2DOpProgramShaderSource =

// Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
// ? = to be determined. : = across all values in that axis.
var dotProd: array<vec4<f32>, ${workPerThread}>;
var dotProd: array<vec4<${dataType}>, ${workPerThread}>;
for (var i = 0; i < ${workPerThread}; i++) {
dotProd[i] = vec4<f32>(0.0);
dotProd[i] = vec4<${dataType}>(0.0);
}
for (var wR: u32 = 0; wR < filterDims[0]; wR = wR + 1) {
var dyR = (f32(dyCorner.x) + f32(wR)) / f32(strides.x);
var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(strides.x);
let wRPerm = filterDims[0] - 1 - wR;
if (dyR < 0.0 || dyR >= f32(outBackprop[1]) ||
if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[1]) ||
fract(dyR) > 0.0 || wRPerm < 0) {
continue;
}
let idyR: u32 = u32(dyR);

for (var wC: u32 = 0; wC < filterDims[1]; wC = wC + 1) {
let dyC = (f32(dyCorner.y) + f32(wC)) / f32(strides.y);
let dyC2 = (f32(dyCorner.y) + 1.0 + f32(wC)) / f32(strides.y);
let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(strides.y);
let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(strides.y);
let wCPerm = filterDims[1] - 1 - wC;
if (wCPerm < 0) {
continue;
}
var bDyCVal = true;
var bDyCVal2 = true;
if (dyC < 0.0 || dyC >= f32(outBackprop[2]) ||
if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[2]) ||
fract(dyC) > 0.0) {
bDyCVal = false;
}
if (dyC2 < 0.0 || dyC2 >= f32(outBackprop[2]) ||
if (dyC2 < 0.0 || dyC2 >= ${dataType}(outBackprop[2]) ||
fract(dyC2) > 0.0) {
bDyCVal2 = false;
}
Expand All @@ -108,15 +109,15 @@ const createConvTranspose2DOpProgramShaderSource =
let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')};

var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')};
let tmpval = vec4<f32>(dot(xValue, wValue0),
let tmpval = vec4<${dataType}>(dot(xValue, wValue0),
dot(xValue, wValue1),
dot(xValue, wValue2),
dot(xValue, wValue3));
dotProd[0] = dotProd[0] + tmpval;

xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')};

dotProd[1] = dotProd[1] + vec4<f32>(dot(xValue, wValue0),
dotProd[1] = dotProd[1] + vec4<${dataType}>(dot(xValue, wValue0),
dot(xValue, wValue1),
dot(xValue, wValue2),
dot(xValue, wValue3));
Expand All @@ -130,7 +131,7 @@ const createConvTranspose2DOpProgramShaderSource =
let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')};

var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')};
let tmpval = vec4<f32>(dot(xValue, wValue0),
let tmpval = vec4<${dataType}>(dot(xValue, wValue0),
dot(xValue, wValue1),
dot(xValue, wValue2),
dot(xValue, wValue3));
Expand All @@ -145,7 +146,7 @@ const createConvTranspose2DOpProgramShaderSource =
let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')};

var xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')};
let tmpval = vec4<f32>(dot(xValue, wValue0),
let tmpval = vec4<${dataType}>(dot(xValue, wValue0),
dot(xValue, wValue1),
dot(xValue, wValue2),
dot(xValue, wValue3));
Expand Down Expand Up @@ -178,9 +179,9 @@ const createConvTranspose2DOpProgramShaderSource =
if (wR % dilations.x != 0) {
continue;
}
let dyR = (f32(dyRCorner) + f32(wR)) / f32(strides[0]);
let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(strides[0]);
let wRPerm = filterDims.x - 1 - wR / dilations.x;
if (dyR < 0.0 || dyR >= f32(outBackprop[${rowDim}]) || fract(dyR) > 0.0 ||
if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[${rowDim}]) || fract(dyR) > 0.0 ||
wRPerm < 0) {
continue;
}
Expand All @@ -190,9 +191,9 @@ const createConvTranspose2DOpProgramShaderSource =
if (wC % dilations.y != 0) {
continue;
}
let dyC = (f32(dyCCorner) + f32(wC)) / f32(strides.y);
let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(strides.y);
let wCPerm = filterDims.y - 1 - wC / dilations.y;
if (dyC < 0.0 || dyC >= f32(outBackprop[${colDim}]) ||
if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[${colDim}]) ||
fract(dyC) > 0.0 || wCPerm < 0) {
continue;
}
Expand Down Expand Up @@ -256,6 +257,7 @@ export const createConvTranspose2DProgramInfo =
];
LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`);

const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
return {
...metadata,
outputs: [{
Expand All @@ -265,6 +267,7 @@ export const createConvTranspose2DProgramInfo =
}],
dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}),
getShaderSource: (shaderHelper: ShaderHelper) => createConvTranspose2DOpProgramShaderSource(
shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1),
shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1, false,
dataType),
};
};
Loading
Loading