Skip to content

Commit

Permalink
FP16 Conv, ConvTranspose and MatMul
Browse files (browse the repository at this point in the history)
  • Loading branch information
dakenf committed Sep 12, 2023
1 parent db558ef commit f060b5a
Show file tree
Hide file tree
Showing 12 changed files with 113 additions and 137 deletions.
10 changes: 5 additions & 5 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@

export declare type Activation = 'linear' | 'relu' | 'prelu' | 'elu' | 'relu6' | 'leakyrelu' | 'sigmoid' | 'gelu';

export const typeSnippet = (component: number) => {
export const typeSnippet = (component: number, dataType: string) => {
switch (component) {
case 1:
return 'f32';
return dataType;
case 2:
return 'vec2<f32>';
return `vec2<${dataType}>`;
case 3:
return 'vec3<f32>';
return `vec3<${dataType}>`;
case 4:
return 'vec4<f32>';
return `vec4<${dataType}>`;
default:
throw new Error(`${component}-component is not supported.`);
}
Expand Down
38 changes: 20 additions & 18 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ import {ConvAttributes} from '../conv';
import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util';
import {utilFunctions} from './conv_util';
import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu';
import { tensorTypeToWsglStorageType } from '../common'

const conv2dCommonSnippet =
(isChannelsLast: boolean, fitAOuter: boolean, fitBOuter: boolean, fitInner: boolean, addBias = false,
activation?: Activation, hasPreluActivationWeights = false, innerElementSizeX = 4, innerElementSizeW = 4,
innerElementSize = 4): string => {
innerElementSize = 4, dataType = 'f32'): string => {
const getXSnippet = (innerElementSize: number) => {
switch (innerElementSize) {
case 1:
Expand Down Expand Up @@ -92,7 +93,7 @@ const conv2dCommonSnippet =
let xRow = outRow * stride[0] + dilation[0] * WRow - pad[0];
let xCol = outCol * stride[1] + dilation[1] * WCol - pad[1];
let xCh = ${col} % inChannels;
var resData = ${typeSnippet(innerElementSizeX)}(0.0);
var resData = ${typeSnippet(innerElementSizeX, dataType)}(0.0);
// The bounds checking is always needed since we use it to pad zero for
// the 'same' padding type.
if (xRow >= 0 && xRow < ${xHeight} && xCol >= 0 && xCol < ${xWidth}) {
Expand All @@ -110,7 +111,7 @@ const conv2dCommonSnippet =
if (row < dimAOuter && col < dimInner) {
${readXSnippet}
}
return ${typeSnippet(innerElementSizeX)}(0.0);`) :
return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) :
(fitInner && fitBOuter ? `
let col = colIn * ${innerElementSizeX};
${readXSnippet}` :
Expand All @@ -119,13 +120,13 @@ const conv2dCommonSnippet =
if (row < dimInner && col < dimBOuter) {
${readXSnippet}
}
return ${typeSnippet(innerElementSizeX)}(0.0);`);
return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`);

const sampleW = `${getWSnippet(innerElementSizeW)}`;

const resType = typeSnippet(innerElementSize);
const aType = isChannelsLast ? typeSnippet(innerElementSizeX) : typeSnippet(innerElementSizeW);
const bType = isChannelsLast ? typeSnippet(innerElementSizeW) : typeSnippet(innerElementSizeX);
const resType = typeSnippet(innerElementSize, dataType);
const aType = isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType);
const bType = isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType);
const userCode = `
${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)}
fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} {
Expand Down Expand Up @@ -190,23 +191,24 @@ export const createConv2DMatMulProgramInfo =
const fitInner = dimInner % tileInner === 0;

const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1];
const t = tensorTypeToWsglStorageType(inputs[0].dataType);

const declareInputs = [
`@group(0) @binding(0) var<storage, read> x: array<${isVec4 && innerElementSize === 4 ? 'vec4<f32>' : 'f32'}>;`,
`@group(0) @binding(1) var<storage, read> w: array<${isVec4 ? 'vec4<f32>' : 'f32'}>;`
`@group(0) @binding(0) var<storage, read> x: array<${isVec4 && innerElementSize === 4 ? `vec4<${t}>` : t}>;`,
`@group(0) @binding(1) var<storage, read> w: array<${isVec4 ? `vec4<${t}>` : t}>;`
];
let declareFunctions = `
fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? 'vec4<f32>' : 'f32'}) {
result[flatIndex] = ${isVec4 ? 'vec4<f32>' : 'f32'}(value);
fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) {
result[flatIndex] = ${isVec4 ? `vec4<${t}>` : t}(value);
}
fn setOutputAtCoords(d0 : i32, d1 : i32, d2 : i32, d3 : i32, value : ${isVec4 ? 'vec4<f32>' : 'f32'}) {
fn setOutputAtCoords(d0 : i32, d1 : i32, d2 : i32, d3 : i32, value : ${isVec4 ? `vec4<${t}>` : t}) {
let flatIndex = getOutputIndexFromCoords(vec4<i32>(d0, d1, d2, d3));
setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value);
}`;
if (hasBias) {
declareInputs.push(`@group(0) @binding(2) var<storage, read> bias: array<${isVec4 ? 'vec4<f32>' : 'f32'}>;`);
declareInputs.push(`@group(0) @binding(2) var<storage, read> bias: array<${isVec4 ? `vec4<${t}>` : t}>;`);
declareFunctions += `
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? 'vec4<f32>' : 'f32'} {
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? `vec4<${t}>` : t} {
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
}`;
}
Expand All @@ -222,7 +224,7 @@ export const createConv2DMatMulProgramInfo =
// dilation : vec2<i32>, dimAOuter : i32, dimBOuter : i32, dimInner : i32 };
${declareInputs.join('')}
@group(0) @binding(${declareInputs.length}) var<storage, read_write> result: array<${
isVec4 ? 'vec4<f32>' : 'f32'}>;
isVec4 ? `vec4<${t}>` : t}>;
//@group(0) @binding(${declareInputs.length + 1}) var<uniform> uniforms: Uniforms;
const xShape : vec4<i32> = vec4<i32>(${inputs[0].dims.join(',')});
Expand All @@ -240,12 +242,12 @@ export const createConv2DMatMulProgramInfo =
${
conv2dCommonSnippet(
isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, undefined, false, elementsSize[0],
elementsSize[1], elementsSize[2])}
elementsSize[1], elementsSize[2], t)}
${
isVec4 ?
makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner) :
makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) :
makeMatMulPackedSource(
elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner, false, undefined,
elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner, false, undefined,
sequentialAccessByThreads)}`
};
};
50 changes: 30 additions & 20 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,14 @@
import {TensorView} from '../../../tensor';
import {ShapeUtil} from '../../../util';
import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types';
import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from '../common';
import {
getBroadcastDims,
IndicesHelper,
inputVariable,
outputVariable,
ShaderHelper,
tensorTypeToWsglStorageType
} from '../common';
import {getActicationSnippet, InternalActivationAttributes} from '../fuse-utils';

import {typeSnippet} from './activation_util';
Expand Down Expand Up @@ -70,8 +77,8 @@ const calculateResultSnippet = (transposeA: boolean, innerElementSize: number) =
};

export const makeMatMulPackedVec4Source =
(workPerThread: number[], workgroupSize: [number, number, number], batchDims?: IndicesHelper, transposeA = false,
tileInner = 32, splitK = false, splitedDimInner = 32): string => {
(workPerThread: number[], workgroupSize: [number, number, number], type = 'f32', batchDims?: IndicesHelper,
transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32): string => {
const tileAOuter = workgroupSize[1] * workPerThread[1];
const tileBOuter = workgroupSize[0] * workPerThread[0];
const tileAWidth = transposeA ? tileAOuter : tileInner;
Expand All @@ -90,8 +97,8 @@ export const makeMatMulPackedVec4Source =
workPerThread[0]} must be 4.`);
}
return `
var<workgroup> mm_Asub : array<array<vec${innerElementSize}<f32>, ${tileAWidth / innerElementSize}>, ${tileAHight}>;
var<workgroup> mm_Bsub : array<array<vec4<f32>, ${tileBOuter / workPerThread[0]}>, ${tileInner}>;
var<workgroup> mm_Asub : array<array<vec${innerElementSize}<${type}>, ${tileAWidth / innerElementSize}>, ${tileAHight}>;
var<workgroup> mm_Bsub : array<array<vec4<${type}>, ${tileBOuter / workPerThread[0]}>, ${tileInner}>;
const rowPerThread = ${workPerThread[1]};
const colPerThread = ${workPerThread[0]};
Expand All @@ -115,7 +122,7 @@ fn main(@builtin(local_invocation_id) localId : vec3<u32>,
let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'};
var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'};
var acc: array<vec4<f32>, rowPerThread>;
var acc: array<vec4<${type}>, rowPerThread>;
// Loop over shared dimension.
let tileRowB = localRow * ${rowPerThreadB};
Expand Down Expand Up @@ -179,8 +186,9 @@ const readDataFromSubASnippet = (transposeA: boolean) =>
// sequentialAccessByThreads means sequential data in memory is accessed by
// threads, instead of a single thread (default behavior).
export const makeMatMulPackedSource =
(workPerThread: number[], workgroupSize: [number, number, number], batchDims?: IndicesHelper, transposeA = false,
tileInner = 32, splitK = false, splitedDimInner = 32, sequentialAccessByThreads = false): string => {
(workPerThread: number[], workgroupSize: [number, number, number], type = 'f32', batchDims?: IndicesHelper,
transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32,
sequentialAccessByThreads = false): string => {
const tileAOuter = workPerThread[1] * workgroupSize[1];
const tileBOuter = workPerThread[0] * workgroupSize[0];
const tileAWidth = transposeA ? tileAOuter : tileInner;
Expand Down Expand Up @@ -222,7 +230,7 @@ export const makeMatMulPackedSource =
workgroupBarrier();
// Compute acc values for a single thread.
var BCached : array<f32, colPerThread>;
var BCached : array<${type}, colPerThread>;
for (var k = 0; k < tileInner; k = k + 1) {
for (var inner = 0; inner < colPerThread; inner = inner + 1) {
BCached[inner] = mm_Bsub[k][localCol + inner * ${workgroupSize[0]}];
Expand Down Expand Up @@ -283,7 +291,7 @@ for (var t = 0; t < numTiles; t = t + 1) {
workgroupBarrier();
// Compute acc values for a single thread.
var BCached : array<f32, colPerThread>;
var BCached : array<${type}, colPerThread>;
for (var k = 0; k < tileInner; k = k + 1) {
for (var inner = 0; inner < colPerThread; inner = inner + 1) {
BCached[inner] = mm_Bsub[k][tileCol + inner];
Expand All @@ -309,8 +317,8 @@ for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {
`;

return `
var<workgroup> mm_Asub : array<array<f32, ${tileAWidth}>, ${tileAHight}>;
var<workgroup> mm_Bsub : array<array<f32, ${tileBOuter}>, ${tileInner}>;
var<workgroup> mm_Asub : array<array<${type}, ${tileAWidth}>, ${tileAHight}>;
var<workgroup> mm_Bsub : array<array<${type}, ${tileBOuter}>, ${tileInner}>;
const rowPerThread = ${workPerThread[1]};
const colPerThread = ${workPerThread[0]};
const tileInner = ${tileInner};
Expand All @@ -324,7 +332,7 @@ fn main(@builtin(local_invocation_id) localId : vec3<u32>,
let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'};
var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'};
var acc : array<array<f32, colPerThread>, rowPerThread>;
var acc : array<array<${type}, colPerThread>, rowPerThread>;
// Without this initialization strange values show up in acc.
for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {
Expand All @@ -347,6 +355,7 @@ const matMulReadWriteFnSource =
const outputVariable = variables[5];
const broadCastADims = getBroadcastDims(batchAVariable.shape, batchVariable.shape);
const broadCastBDims = getBroadcastDims(batchBVariable.shape, batchVariable.shape);
const dataType = tensorTypeToWsglStorageType(variables[0].type.tensor);
const getAIndices = () => {
const aRank = aVariable.shape.length;
const batchRank = batchVariable.shape.length;
Expand Down Expand Up @@ -377,8 +386,8 @@ const matMulReadWriteFnSource =
};
const source = `
fn mm_readA(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${
typeSnippet(component)} {
var value = ${typeSnippet(component)}(0.0);
typeSnippet(component, dataType)} {
var value = ${typeSnippet(component, dataType)}(0.0);
let col = colIn * ${component};
if(row < dimAOuter && col < dimInner)
{
Expand All @@ -389,8 +398,8 @@ const matMulReadWriteFnSource =
}
fn mm_readB(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${
typeSnippet(component)} {
var value = ${typeSnippet(component)}(0.0);
typeSnippet(component, dataType)} {
var value = ${typeSnippet(component, dataType)}(0.0);
let col = colIn * ${component};
if(row < dimInner && col < dimBOuter)
{
Expand All @@ -400,7 +409,7 @@ const matMulReadWriteFnSource =
return value;
}
fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component)}) {
fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component, dataType)}) {
let col = colIn * ${component};
if (row < dimAOuter && col < dimBOuter) {
var value = valueIn;
Expand Down Expand Up @@ -444,6 +453,7 @@ export const createMatmulProgramInfo =
Math.ceil(batchSize / workgroupSize[2] / elementsPerThread[2])
];

const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const components = isVec4 ? 4 : 1;
const A = inputVariable('a', inputs[0].dataType, [...outerDimsA, dimAOuter, dimInner / components], components);
const B = inputVariable('b', inputs[1].dataType, [...outerDimsB, dimInner, dimBOuter / components], components);
Expand All @@ -466,8 +476,8 @@ export const createMatmulProgramInfo =
${declareFunctions}
${activationFunction}
${
isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, batchDims) :
makeMatMulPackedSource(elementsPerThread, workgroupSize, batchDims)}
isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) :
makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)}
${batchDims.impl()}`;
return {
...metadata,
Expand Down
10 changes: 0 additions & 10 deletions js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types';
Expand Down Expand Up @@ -197,15 +196,6 @@ const validateInputs = (inputs: readonly TensorView[], attributes: ConvTranspose
if (attributes.outputShape.length !== 0 && attributes.outputShape.length !== inputs[0].dims.length - 2) {
throw new Error('invalid output shape');
}

// TODO : Need to add support for float64
if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) {
throw new Error('ConvTranspose input(X,W) should be float tensor');
}

if (inputs.length === 3 && inputs[2].dataType !== DataType.float) {
throw new Error('ConvTranspose input(bias) should be float tensor');
}
};

const createConvTranspose2DProgramMetadata = (hasBias: boolean, cacheHint: string): ProgramMetadata => ({
Expand Down
10 changes: 0 additions & 10 deletions js/web/lib/wasm/jsep/webgpu/ops/conv.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {PoolConvUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
Expand Down Expand Up @@ -93,15 +92,6 @@ const validateInputs = (inputs: readonly TensorView[], attributes: ConvAttribute
if (attributes.kernelShape.length !== 0 && attributes.kernelShape.length !== inputs[1].dims.length - 2) {
throw new Error('invalid kernel shape');
}

// TODO : Need to add support for float64
if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) {
throw new Error('Conv input(X,W) should be float tensor');
}

if (inputs.length === 3 && inputs[2].dataType !== DataType.float) {
throw new Error('Conv input(bias) should be float tensor');
}
};

const getAdjustedConvAttributes = <T extends ConvAttributes>(attributes: T, inputs: readonly TensorView[]): T => {
Expand Down
5 changes: 0 additions & 5 deletions js/web/lib/wasm/jsep/webgpu/ops/matmul.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {BroadcastUtil} from '../../util';
import {ComputeContext, GpuDataType, ProgramInfoLoader} from '../types';
Expand Down Expand Up @@ -35,10 +34,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
if (inputs[0].dims[inputs[0].dims.length - 1] !== inputs[1].dims[inputs[1].dims.length - 2]) {
throw new Error('shared dimension does not match.');
}

if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) {
throw new Error('inputs should be float type');
}
};

export const matMul = (context: ComputeContext): void => {
Expand Down
Loading

0 comments on commit f060b5a

Please sign in to comment.