Skip to content

Commit

Permalink
[js/webgpu] Refactor matmul conv to support uniforms for matmul (microsoft#18452)
Browse files Browse the repository at this point in the history

This change refactors the matmul- and conv-related programs to support shape
uniforms. Currently, only the matmul shape uniforms are fully enabled.
TODOs: add input dependencies for the conv-related programs, and turn clipMax
and clipMin into uniforms.
  • Loading branch information
axinging authored and kleiti committed Mar 22, 2024
1 parent 23b69f3 commit d603b27
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 98 deletions.
73 changes: 39 additions & 34 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@

import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ShapeUtil} from '../../../util';
import {ProgramInfo} from '../../types';
import {tensorTypeToWsglStorageType} from '../common';
import {ProgramInfo, ProgramUniform} from '../../types';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common';
import {ConvAttributes} from '../conv';
import {getActivationSnippet} from '../fuse-utils';

Expand All @@ -50,9 +49,9 @@ const conv2dCommonSnippet =
const getWSnippet = (innerElementSize: number) => {
switch (innerElementSize) {
case 1:
return 'return w[row * wShape[3] + colIn];';
return 'return w[row * i32(uniforms.w_shape[3]) + colIn];';
case 4:
return 'return w[row * wShape[3] / 4 + colIn];';
return 'return w[row * i32(uniforms.w_shape[3]) / 4 + colIn];';
default:
throw new Error(`innerElementSize ${innerElementSize} is not supported.`);
}
Expand All @@ -79,13 +78,13 @@ const conv2dCommonSnippet =
col % outWidth);
`;

const xHeight = isChannelsLast ? 'xShape[1]' : 'xShape[2]';
const xWidth = isChannelsLast ? 'xShape[2]' : 'xShape[3]';
const xHeight = isChannelsLast ? 'i32(uniforms.x_shape[1])' : 'i32(uniforms.x_shape[2])';
const xWidth = isChannelsLast ? 'i32(uniforms.x_shape[2])' : 'i32(uniforms.x_shape[3])';
const row = isChannelsLast ? 'row' : 'col';
const col = isChannelsLast ? 'col' : 'row';
const readXSnippet = `
let inChannels = wShape[2];
let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'};
let inChannels = i32(uniforms.w_shape[2]);
let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'};
let outRow = ${row} / outWidth;
let outCol = ${row} % outWidth;
Expand All @@ -99,7 +98,7 @@ const conv2dCommonSnippet =
// the 'same' padding type.
if (xRow >= 0 && xRow < ${xHeight} && xCol >= 0 && xCol < ${xWidth}) {
${coordASnippet}
let xIndex = getIndexFromCoords4D(coord, xShape);
let xIndex = getIndexFromCoords4D(coord, vec4<i32>(uniforms.x_shape));
${getXSnippet(innerElementSizeX)}
}
return resData;`;
Expand All @@ -109,7 +108,7 @@ const conv2dCommonSnippet =
${readXSnippet}` :
`
let col = colIn * ${innerElementSizeX};
if (row < dimAOuter && col < dimInner) {
if (row < uniforms.dimAOuter && col < uniforms.dimInner) {
${readXSnippet}
}
return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) :
Expand All @@ -118,7 +117,7 @@ const conv2dCommonSnippet =
${readXSnippet}` :
`
let col = colIn * ${innerElementSizeX};
if (row < dimInner && col < dimBOuter) {
if (row < uniforms.dimInner && col < uniforms.dimBOuter) {
${readXSnippet}
}
return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`);
Expand All @@ -143,10 +142,10 @@ const conv2dCommonSnippet =
fn mm_write(batch: i32, row : i32, colIn : i32, valueIn : ${resType}) {
let col = colIn * ${innerElementSize};
if (row < dimAOuter && col < dimBOuter)
if (row < uniforms.dimAOuter && col < uniforms.dimBOuter)
{
var value = valueIn;
let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'};
let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'};
${coordResSnippet}
${biasSnippet(addBias)}
${applyActivation}
Expand Down Expand Up @@ -194,10 +193,17 @@ export const createConv2DMatMulProgramInfo =
const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1];
const t = tensorTypeToWsglStorageType(inputs[0].dataType);

const declareInputs = [
`@group(0) @binding(0) var<storage, read> x: array<${isVec4 && innerElementSize === 4 ? `vec4<${t}>` : t}>;`,
`@group(0) @binding(1) var<storage, read> w: array<${isVec4 ? `vec4<${t}>` : t}>;`
];
// TODO: support component 2, 3.
const components = isVec4 ? 4 : 1;
const programUniforms: ProgramUniform[] =
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components);
const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components);
const inputVariables = [x, w];

programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
programUniforms.push(...createTensorShapeVariables(inputs[1].dims));

let declareFunctions = `
fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) {
result[flatIndex] = ${isVec4 ? `vec4<${t}>` : t}(value);
Expand All @@ -207,41 +213,40 @@ export const createConv2DMatMulProgramInfo =
setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value);
}`;
if (hasBias) {
declareInputs.push(`@group(0) @binding(2) var<storage, read> bias: array<${isVec4 ? `vec4<${t}>` : t}>;`);
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
inputVariables.push(bias);

programUniforms.push(...createTensorShapeVariables(inputs[2].dims));

declareFunctions += `
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? `vec4<${t}>` : t} {
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
}`;
}

const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
programUniforms.push(...createTensorShapeVariables(outputShape));
return {
name: 'Conv2DMatMul',
shaderCache: {hint: attributes.cacheKey},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
programUniforms,
}),
getShaderSource: () => `
${utilFunctions}
getShaderSource: (shaderHelper: ShaderHelper) => `
${utilFunctions('uniforms.result_strides')}
//struct Uniforms { xShape : vec4<i32>, wShape : vec4<i32>, outShape : vec4<i32>,
// outShapeStrides: vec3<i32>, filterDims : vec2<i32>, pad : vec2<i32>, stride : vec2<i32>,
// dilation : vec2<i32>, dimAOuter : i32, dimBOuter : i32, dimInner : i32 };
${declareInputs.join('')}
@group(0) @binding(${declareInputs.length}) var<storage, read_write> result: array<${
isVec4 ? `vec4<${t}>` : t}>;
//@group(0) @binding(${declareInputs.length + 1}) var<uniform> uniforms: Uniforms;
const xShape : vec4<i32> = vec4<i32>(${inputs[0].dims.join(',')});
const wShape : vec4<i32> = vec4<i32>(${inputs[1].dims.join(',')});
const outShape : vec4<i32> = vec4<i32>(${outputShape.join(',')});
const outShapeStrides : vec3<i32> = vec3<i32>(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')});
${
shaderHelper.registerUniform('dimAOuter', 'i32')
.registerUniform('dimBOuter', 'i32')
.registerUniform('dimInner', 'i32')
.declareVariables(...inputVariables, output)}
const filterDims : vec2<i32> = vec2<i32>(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]});
const pad : vec2<i32> = vec2<i32>(${attributes.pads[0]}, ${attributes.pads[1]});
const stride : vec2<i32> = vec2<i32>(${attributes.strides[0]}, ${attributes.strides[1]});
const dilation : vec2<i32> = vec2<i32>(${attributes.dilations[0]}, ${attributes.dilations[1]});
const dimAOuter : i32 = ${dimAOuter};
const dimBOuter : i32 = ${dimBOuter};
const dimInner : i32 = ${dimInner};
${declareFunctions}
${
conv2dCommonSnippet(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ShapeUtil} from '../../../util';
import {ProgramInfo} from '../../types';
import {ProgramInfo, ProgramUniform} from '../../types';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from '../common';
import {ConvTransposeAttributes} from '../conv-transpose';
import {getActivationSnippet} from '../fuse-utils';

Expand All @@ -36,16 +36,16 @@ const conv2dTransposeCommonSnippet =
const getWSnippet = (innerElementSize: number) => {
switch (innerElementSize) {
case 1:
return 'return W[getIndexFromCoords4D(coord, wShape)];';
return 'return w[getIndexFromCoords4D(coord, vec4<i32>(uniforms.w_shape))];';
case 4:
return `
let coord1 = vec4<i32>(coordX, coordY, col + 1, rowInner);
let coord2 = vec4<i32>(coordX, coordY, col + 2, rowInner);
let coord3 = vec4<i32>(coordX, coordY, col + 3, rowInner);
let v0 = W[getIndexFromCoords4D(coord, wShape)];
let v1 = W[getIndexFromCoords4D(coord1, wShape)];
let v2 = W[getIndexFromCoords4D(coord2, wShape)];
let v3 = W[getIndexFromCoords4D(coord3, wShape)];
let v0 = w[getIndexFromCoords4D(coord, vec4<i32>(uniforms.w_shape))];
let v1 = w[getIndexFromCoords4D(coord1, vec4<i32>(uniforms.w_shape))];
let v2 = w[getIndexFromCoords4D(coord2, vec4<i32>(uniforms.w_shape))];
let v3 = w[getIndexFromCoords4D(coord3, vec4<i32>(uniforms.w_shape))];
return vec4<f32>(v0, v1, v2, v3);
`;
default:
Expand Down Expand Up @@ -81,7 +81,7 @@ const conv2dTransposeCommonSnippet =

const readASnippet = `
let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'};
let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'};
let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'};
let outRow = ${row} / outWidth;
let outCol = ${row} % outWidth;
Expand All @@ -99,17 +99,17 @@ const conv2dTransposeCommonSnippet =
let iXC = i32(xC);
let xCh = ${col} % inChannels;
${coordASnippet}
return x[getIndexFromCoords4D(coord, xShape)/${innerElementSize}];`;
return x[getIndexFromCoords4D(coord, vec4<i32>(uniforms.x_shape))/${innerElementSize}];`;

const sampleA = isChannelsLast ? `
let col = colIn * ${innerElementSize};
if (row < dimAOuter && col < dimInner) {
if (row < uniforms.dimAOuter && col < uniforms.dimInner) {
${readASnippet}
}
return ${type}(0.0);` :
`
let col = colIn * ${innerElementSize};
if (row < dimInner && col < dimBOuter) {
if (row < uniforms.dimInner && col < uniforms.dimBOuter) {
${readASnippet}
}
return ${type}(0.0);`;
Expand All @@ -120,8 +120,8 @@ const conv2dTransposeCommonSnippet =
let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels);
let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1];
if (${
isChannelsLast ? 'row < dimInner && col < dimBOuter' :
'row < dimInner && col < dimAOuter'} && coordX >= 0 && coordY >= 0) {
isChannelsLast ? 'row < uniforms.dimInner && col < uniforms.dimBOuter' :
'row < uniforms.dimInner && col < uniforms.dimAOuter'} && coordX >= 0 && coordY >= 0) {
let rowInner = row % inChannels;
let coord = vec4<i32>(coordX, coordY, col, rowInner);
${getWSnippet(innerElementSize)}
Expand All @@ -142,13 +142,13 @@ const conv2dTransposeCommonSnippet =
fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${type}) {
let col = colIn * ${innerElementSize};
if (row < dimAOuter && col < dimBOuter) {
if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) {
var value = valueInput;
let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'};
let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'};
${coordResSnippet}
${biasSnippet(addBias)}
${applyActivation}
result[getIndexFromCoords4D(coords, outShape)/${innerElementSize}] = value;
result[getIndexFromCoords4D(coords, vec4<i32>(uniforms.result_shape))/${innerElementSize}] = value;
}
}`;
return userCode;
Expand Down Expand Up @@ -185,37 +185,46 @@ export const createConv2DTransposeMatMulProgramInfo =

const innerElementSize = isVec4 ? 4 : 1;
const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]);
const components = isVec4 ? 4 : 1;
const programUniforms: ProgramUniform[] =
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components);
const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1);
const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
const inputVariables = [x, w];
programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
programUniforms.push(...createTensorShapeVariables(inputs[1].dims));


const declareInputs = [
`@group(0) @binding(0) var<storage, read> x: array<${isVec4 ? 'vec4<f32>' : 'f32'}>;`,
'@group(0) @binding(1) var<storage, read> W: array<f32>;'
];
let declareFunctions = '';
if (hasBias) {
declareInputs.push(`@group(0) @binding(2) var<storage, read> bias: array<${isVec4 ? 'vec4<f32>' : 'f32'}>;`);
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
inputVariables.push(bias);
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));

declareFunctions += `
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? 'vec4<f32>' : 'f32'} {
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
}`;
}

programUniforms.push(...createTensorShapeVariables(outputShape));

return {
name: 'Conv2DTransposeMatMul',
shaderCache: {hint: attributes.cacheKey},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
programUniforms
}),
getShaderSource: () => `
${utilFunctions}
${declareInputs.join('\n')}
@group(0) @binding(${declareInputs.length}) var<storage, read_write> result: array<${
isVec4 ? 'vec4<f32>' : 'f32'}>;
getShaderSource: (shaderHelper: ShaderHelper) => `
${utilFunctions('uniforms.result_strides')}
${
shaderHelper.registerUniform('dimAOuter', 'i32')
.registerUniform('dimBOuter', 'i32')
.registerUniform('dimInner', 'i32')
.declareVariables(...inputVariables, output)};
const outBackprop : vec4<i32> = vec4<i32>(${inputs[0].dims.join(',')});
const xShape : vec4<i32> = vec4<i32>(${inputs[0].dims.join(',')});
const wShape : vec4<i32> = vec4<i32>(${inputs[1].dims.join(',')});
const outShape : vec4<i32> = vec4<i32>(${outputShape.join(',')});
const outShapeStrides : vec3<i32> = vec3<i32>(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')});
const filterDims : vec2<i32> = vec2<i32>(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${
attributes.kernelShape[isChannelsLast ? 2 : 3]});
const effectiveFilterDims : vec2<i32> = filterDims + vec2<i32>(
Expand Down
6 changes: 3 additions & 3 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
//
// modified to fit the needs of the project

/**
 * Builds the WGSL helper functions shared by the conv2d matmul shader programs.
 *
 * The generated source declares two functions:
 * - `getIndexFromCoords4D(coords, shape)`: flattens 4D coordinates into a linear
 *   index using strides derived from `shape`.
 * - `getOutputIndexFromCoords(coords)`: flattens 4D output coordinates using the
 *   strides read from `strideStr`.
 *
 * @param strideStr - WGSL expression that evaluates to the output strides vector
 *   (callers pass `'uniforms.result_strides'`). Its `.x`, `.y` and `.z` components
 *   are each wrapped in `i32(...)` so they can be combined with the `vec4<i32>`
 *   coordinates; the innermost stride is the literal `1`.
 * @returns The WGSL source text to be spliced into a shader.
 */
export const utilFunctions = (strideStr: string): string => (`
fn getIndexFromCoords4D(coords : vec4<i32>, shape : vec4<i32>) -> i32 {
return dot(coords, vec4<i32>(
shape.y * shape.z * shape.w, shape.z * shape.w, shape.w, 1));
}
fn getOutputIndexFromCoords(coords : vec4<i32>) -> i32 {
return dot(coords, vec4<i32>(
i32(${strideStr}.x), i32(${strideStr}.y), i32(${strideStr}.z), 1));
}
`);
Loading

0 comments on commit d603b27

Please sign in to comment.