Skip to content

Commit

Permalink
Attention WIP, Conv speedup
Browse files Browse the repository at this point in the history
  • Loading branch information
dakenf committed Sep 28, 2023
1 parent 1efc5bd commit e797f53
Show file tree
Hide file tree
Showing 6 changed files with 288 additions and 280 deletions.
13 changes: 5 additions & 8 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import {ConvAttributes} from '../conv';
import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util';
import {utilFunctions} from './conv_util';
import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu';
import { tensorTypeToWsglStorageType } from '../common'
import { tensorTypeToWsglStorageType } from '../common';

const conv2dCommonSnippet =
(isChannelsLast: boolean, fitAOuter: boolean, fitBOuter: boolean, fitInner: boolean, addBias = false,
Expand Down Expand Up @@ -161,17 +161,14 @@ export const createConv2DMatMulProgramInfo =
const outWidth = isChannelsLast ? outputShape[2] : outputShape[3];
const outHeight = isChannelsLast ? outputShape[1] : outputShape[2];
const outChannels = isChannelsLast ? outputShape[3] : outputShape[1];
const isVec4 = (((inChannels % 4 === 0 || inChannels % 3 === 0) && isChannelsLast) ||
(outWidth % 4 === 0 && !isChannelsLast)) &&
outChannels % 4 === 0;
// TODO: enable vec4 for NCHW
const isVec4 = isChannelsLast && (inChannels % 4 === 0 || inChannels % 3 === 0) && outChannels % 4 === 0;

// TODO: fine tune size
const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight;
const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels;
const workGroupSize: [number, number, number] =
isVec4 ? [8, 8, 1] : [dispatchX <= 4 ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1];
const elementsPerThread =
isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 2, dispatchX > 4 && dispatchY <= 4 ? 1 : 2, 1];
const workGroupSize: [number, number, number] = [8, 8, 1];
const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1];
const dispatch = [
Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]),
Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]),
Expand Down
237 changes: 120 additions & 117 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ export const makeMatMulPackedVec4Source =
const innerElementSize = tileAWidth / workgroupSize[0];
const rowPerThreadB = tileInner / workgroupSize[1];

if (!(((transposeA && innerElementSize === 4 && workPerThread[1] === 4) ||
(!transposeA && (innerElementSize === 3 || innerElementSize === 4))) &&
tileAWidth % workgroupSize[0] === 0 && tileInner % workgroupSize[1] === 0 && workPerThread[0] === 4)) {
throw new Error(`If transposeA ${transposeA} is true, innerElementSize ${
innerElementSize} and workPerThread[1] ${workPerThread[1]} must be 4.
if (!(((transposeA && innerElementSize === 4 && workPerThread[1] === 4) ||
(!transposeA && (innerElementSize === 3 || innerElementSize === 4))) &&
tileAWidth % workgroupSize[0] === 0 && tileInner % workgroupSize[1] === 0 && workPerThread[0] === 4)) {
throw new Error(`If transposeA ${transposeA} is true, innerElementSize ${
innerElementSize} and workPerThread[1] ${workPerThread[1]} must be 4.
Otherwise, innerElementSize ${innerElementSize} must be 3 or 4.
tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${workgroupSize[0]}. tileInner ${
tileInner} must be divisible by workgroupSize[1] ${workgroupSize[1]}. colPerThread ${
Expand Down Expand Up @@ -139,7 +139,7 @@ fn main(@builtin(local_invocation_id) localId : vec3<u32>,
let inputRow = tileRowB + innerRow;
let inputCol = tileCol;
mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalCol${
batchDims ? ', batchIndices' : ''});
batchDims ? ', batchIndices' : ''});
}
kStart = kStart + tileInner;
workgroupBarrier();
Expand All @@ -161,7 +161,7 @@ fn main(@builtin(local_invocation_id) localId : vec3<u32>,
mm_write(batch, globalRow + innerRow, globalCol, acc[innerRow]);
}
}`;
};
};

const writeDataToSubASnippet = (transpose: boolean, batchDims?: IndicesHelper) => {
if (transpose) {
Expand All @@ -181,7 +181,7 @@ const writeDataToSubASnippet = (transpose: boolean, batchDims?: IndicesHelper) =
};

const readDataFromSubASnippet = (transposeA: boolean) =>
transposeA ? 'let ACached = mm_Asub[k][tileRow + innerRow];' : 'let ACached = mm_Asub[tileRow + innerRow][k];';
transposeA ? 'let ACached = mm_Asub[k][tileRow + innerRow];' : 'let ACached = mm_Asub[tileRow + innerRow][k];';

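// When sequentialAccessByThreads is true, consecutive elements are spread across
// neighbouring threads (each thread strides by the workgroup size), instead of each
// thread processing its own contiguous run of elements (default behavior).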
Expand All @@ -194,17 +194,17 @@ export const makeMatMulPackedSource =
const tileAWidth = transposeA ? tileAOuter : tileInner;
const tileAHight = transposeA ? tileInner : tileAOuter;

if (!(tileAHight % workgroupSize[1] === 0 && tileAWidth % workgroupSize[0] === 0 &&
tileInner % workgroupSize[1] === 0)) {
throw new Error(`tileAHight ${tileAHight} must be divisible by workgroupSize[1]${
workgroupSize[1]}, tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${
workgroupSize[0]}, tileInner ${tileInner} must be divisible by workgroupSize[1]${workgroupSize[1]}`);
}
const rowPerThreadA = tileAHight / workgroupSize[1];
const colPerThreadA = tileAWidth / workgroupSize[0];
const rowPerThreadB = tileInner / workgroupSize[1];
const matmulSnippet = sequentialAccessByThreads ?
`
if (!(tileAHight % workgroupSize[1] === 0 && tileAWidth % workgroupSize[0] === 0 &&
tileInner % workgroupSize[1] === 0)) {
throw new Error(`tileAHight ${tileAHight} must be divisible by workgroupSize[1]${
workgroupSize[1]}, tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${
workgroupSize[0]}, tileInner ${tileInner} must be divisible by workgroupSize[1]${workgroupSize[1]}`);
}
const rowPerThreadA = tileAHight / workgroupSize[1];
const colPerThreadA = tileAWidth / workgroupSize[0];
const rowPerThreadB = tileInner / workgroupSize[1];
const matmulSnippet = sequentialAccessByThreads ?
`
let localRow = i32(localId.y);
let localCol = i32(localId.x);
let globalRowStart = i32(workgroupId.y) * ${tileAOuter};
Expand Down Expand Up @@ -237,8 +237,8 @@ export const makeMatMulPackedSource =
}
for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {
let ACached = ${
transposeA ? `mm_Asub[k][localRow + innerRow * ${workgroupSize[1]}];` :
`mm_Asub[localRow + innerRow * ${workgroupSize[1]}][k];`}
transposeA ? `mm_Asub[k][localRow + innerRow * ${workgroupSize[1]}];` :
`mm_Asub[localRow + innerRow * ${workgroupSize[1]}][k];`}
for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {
acc[innerRow][innerCol] = acc[innerRow][innerCol] +
ACached * BCached[innerCol];
Expand All @@ -255,7 +255,7 @@ export const makeMatMulPackedSource =
}
}
` :
`
`
let tileRow = i32(localId.y) * rowPerThread;
let tileCol = i32(localId.x) * colPerThread;
Expand Down Expand Up @@ -343,48 +343,49 @@ fn main(@builtin(local_invocation_id) localId : vec3<u32>,
${matmulSnippet}
}
`;
};
};

const matMulReadWriteFnSource =
(component: number, hasBias: boolean, applyActivation: string, variables: IndicesHelper[]): string => {
const batchAVariable = variables[0];
const batchBVariable = variables[1];
const batchVariable = variables[2];
const aVariable = variables[3];
const bVariable = variables[4];
const outputVariable = variables[5];
const broadCastADims = getBroadcastDims(batchAVariable.shape, batchVariable.shape);
const broadCastBDims = getBroadcastDims(batchBVariable.shape, batchVariable.shape);
const dataType = tensorTypeToWsglStorageType(variables[0].type.tensor);
const getAIndices = () => {
const aRank = aVariable.shape.length;
const batchRank = batchVariable.shape.length;
let resStr = `var aIndices: ${aVariable.type.indices};`;
for (let i = aRank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) {
resStr += `\naIndices[${i}] = ${batchRank > 1 ? `batchIndices[${j}]` : 'batchIndices'};`;
}
broadCastADims.forEach(i => {
resStr += `\naIndices[${i}] = 0;`;
});
resStr += `\naIndices[${aRank - 2}] = u32(row);
(component: number, hasBias: boolean, applyActivation: string, variables: IndicesHelper[], isChannelsLast = false):
string => {
const batchAVariable = variables[0];
const batchBVariable = variables[1];
const batchVariable = variables[2];
const aVariable = variables[3];
const bVariable = variables[4];
const outputVariable = variables[5];
const broadCastADims = getBroadcastDims(batchAVariable.shape, batchVariable.shape);
const broadCastBDims = getBroadcastDims(batchBVariable.shape, batchVariable.shape);
const dataType = tensorTypeToWsglStorageType(variables[0].type.tensor);
const getAIndices = () => {
const aRank = aVariable.shape.length;
const batchRank = batchVariable.shape.length;
let resStr = `var aIndices: ${aVariable.type.indices};`;
for (let i = aRank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) {
resStr += `\naIndices[${i}] = ${batchRank > 1 ? `batchIndices[${j}]` : 'batchIndices'};`;
}
broadCastADims.forEach(i => {
resStr += `\naIndices[${i}] = 0;`;
});
resStr += `\naIndices[${aRank - 2}] = u32(row);
aIndices[${aRank - 1}] = u32(colIn);`;
return resStr;
};
const getBIndices = () => {
const bRank = bVariable.shape.length;
const batchRank = batchVariable.shape.length;
let resStr = `var bIndices: ${bVariable.type.indices};`;
for (let i = bRank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) {
resStr += `\nbIndices[${i}] = ${batchRank > 1 ? `batchIndices[${j}]` : 'batchIndices'};`;
}
broadCastBDims.forEach(i => {
resStr += `\nbIndices[${i}] = 0;`;
});
resStr += `\nbIndices[${bRank - 2}] = u32(row);
return resStr;
};
const getBIndices = () => {
const bRank = bVariable.shape.length;
const batchRank = batchVariable.shape.length;
let resStr = `var bIndices: ${bVariable.type.indices};`;
for (let i = bRank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) {
resStr += `\nbIndices[${i}] = ${batchRank > 1 ? `batchIndices[${j}]` : 'batchIndices'};`;
}
broadCastBDims.forEach(i => {
resStr += `\nbIndices[${i}] = 0;`;
});
resStr += `\nbIndices[${bRank - 2}] = u32(row);
bIndices[${bRank - 1}] = u32(colIn);`;
return resStr;
};
const source = `
return resStr;
};
const source = `
fn mm_readA(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${
typeSnippet(component, dataType)} {
var value = ${typeSnippet(component, dataType)}(0.0);
Expand Down Expand Up @@ -414,75 +415,77 @@ const matMulReadWriteFnSource =
if (row < dimAOuter && col < dimBOuter) {
var value = valueIn;
let coords = vec3<i32>(batch, row, colIn);
${hasBias ? 'value = value + bias[colIn];' : ''}
${hasBias ? `value = value + ${isChannelsLast ? 'bias[colIn]' : `${typeSnippet(component, dataType)}(bias[row])`};` : ''}
${applyActivation}
${outputVariable.setByIndices('vec3<u32>(coords)', 'value')}
}
}
`;
return source;
};
return source;
};

export const createMatmulProgramInfo =
(metadata: ProgramMetadata, inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes,
outputShape: readonly number[], reshapedOutputShape?: readonly number[]): ProgramInfo => {
const aShape = inputs[0].dims;
const bShape = inputs[1].dims;

const outerDimsA = aShape.slice(0, -2);
const outerDimsB = bShape.slice(0, -2);
const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2);
const batchDims = inputVariable('batchDims', inputs[0].dataType, outerDims);
const batchADims = inputVariable('batchADims', inputs[0].dataType, outerDimsA);
const batchBDims = inputVariable('batchBDims', inputs[0].dataType, outerDimsB);
const variables = [batchADims, batchBDims, batchDims];
const batchSize = ShapeUtil.size(outerDims);

const dimAOuter = aShape[aShape.length - 2];
const dimInner = aShape[aShape.length - 1];
const dimBOuter = bShape[bShape.length - 1];
const isVec4 = dimInner % 4 === 0 && dimBOuter % 4 === 0;
const {activationFunction, applyActivation} = getActicationSnippet(activationAttributes);

// TODO: fine tune size
const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1];
const workgroupSize: [number, number, number] = [8, 8, 1];
const dispatch = [
Math.ceil(dimBOuter / workgroupSize[0] / elementsPerThread[0]),
Math.ceil(dimAOuter / workgroupSize[1] / elementsPerThread[1]),
Math.ceil(batchSize / workgroupSize[2] / elementsPerThread[2])
];

const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const components = isVec4 ? 4 : 1;
const A = inputVariable('a', inputs[0].dataType, [...outerDimsA, dimAOuter, dimInner / components], components);
const B = inputVariable('b', inputs[1].dataType, [...outerDimsB, dimInner, dimBOuter / components], components);
const output =
outputVariable('result', inputs[0].dataType, [batchSize, dimAOuter, dimBOuter / components], components);
variables.push(A);
variables.push(B);
variables.push(output);
const inputVariables = [A, B];
const hasBias = inputs.length > 2;
const declareFunctions = matMulReadWriteFnSource(components, hasBias, applyActivation, variables);
if (hasBias) {
inputVariables.push(inputVariable('bias', inputs[2].dataType, [dimBOuter / components], components));
}
const getShaderSource = (shaderHelper: ShaderHelper) => `
(metadata: ProgramMetadata, inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes,
outputShape: readonly number[], reshapedOutputShape?: readonly number[],
isChannelsLast = false /* only used for conv2dByMatMul*/): ProgramInfo => {
const aShape = inputs[0].dims;
const bShape = inputs[1].dims;

const outerDimsA = aShape.slice(0, -2);
const outerDimsB = bShape.slice(0, -2);
const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2);
const batchDims = inputVariable('batchDims', inputs[0].dataType, outerDims);
const batchADims = inputVariable('batchADims', inputs[0].dataType, outerDimsA);
const batchBDims = inputVariable('batchBDims', inputs[0].dataType, outerDimsB);
const variables = [batchADims, batchBDims, batchDims];
const batchSize = ShapeUtil.size(outerDims);

const dimAOuter = aShape[aShape.length - 2];
const dimInner = aShape[aShape.length - 1];
const dimBOuter = bShape[bShape.length - 1];
const isVec4 = dimInner % 4 === 0 && dimBOuter % 4 === 0;
const {activationFunction, applyActivation} = getActicationSnippet(activationAttributes);

// TODO: fine tune size
const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1];
const workgroupSize: [number, number, number] = [8, 8, 1];
const dispatch = [
Math.ceil(dimBOuter / workgroupSize[0] / elementsPerThread[0]),
Math.ceil(dimAOuter / workgroupSize[1] / elementsPerThread[1]),
Math.ceil(batchSize / workgroupSize[2] / elementsPerThread[2])
];

const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const components = isVec4 ? 4 : 1;
const A = inputVariable('a', inputs[0].dataType, [...outerDimsA, dimAOuter, dimInner / components], components);
const B = inputVariable('b', inputs[1].dataType, [...outerDimsB, dimInner, dimBOuter / components], components);
const output =
outputVariable('result', inputs[0].dataType, [batchSize, dimAOuter, dimBOuter / components], components);
variables.push(A);
variables.push(B);
variables.push(output);
const inputVariables = [A, B];
const hasBias = inputs.length > 2;
const declareFunctions = matMulReadWriteFnSource(components, hasBias, applyActivation, variables, isChannelsLast);
if (hasBias) {
const biasComponents = isChannelsLast ? components : 1;
inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims, biasComponents));
}
const getShaderSource = (shaderHelper: ShaderHelper) => `
const dimAOuter: i32 = ${dimAOuter};
const dimBOuter: i32 = ${dimBOuter};
const dimInner: i32 = ${dimInner};
${shaderHelper.declareVariables(...inputVariables, output)}
${declareFunctions}
${activationFunction}
${
isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) :
makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)}
isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) :
makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)}
${batchDims.impl()}`;
return {
...metadata,
outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}],
getShaderSource,
dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]})
};
};
return {
...metadata,
outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}],
getShaderSource,
dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]})
};
};
Loading

0 comments on commit e797f53

Please sign in to comment.