Skip to content

Commit

Permalink
Fixed filter setting
Browse files (browse the repository at this point in the history)
  • Loading branch information
satyajandhyala committed Sep 20, 2023
1 parent 8ed9038 commit 6794407
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 30 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -25,17 +25,18 @@ import {ShapeUtil} from '../../../util';
import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types';
import {ConvTransposeAttributes} from '../conv-transpose';

import {typeSnippet} from './activation_util';
import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util';
import {utilFunctions} from './conv_util';
import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu';

const conv2dTransposeCommonSnippet = (innerElementSize = 4): string => {
const getWSnippet = (innerElementSize: number) => {
switch (innerElementSize) {
case 1:
return 'return W[getIndexFromCoords4D(coord, wShape)];';
case 4:
return `
const conv2dTransposeCommonSnippet =
(addBias = false, activation?: Activation, hasPreluActivationWeights = false, innerElementSize = 4): string => {
const getWSnippet = (innerElementSize: number) => {
switch (innerElementSize) {
case 1:
return 'return W[getIndexFromCoords4D(coord, wShape)];';
case 4:
return `
let coord1 = vec4<i32>(coordX, coordY, col + 1, rowInner);
let coord2 = vec4<i32>(coordX, coordY, col + 2, rowInner);
let coord3 = vec4<i32>(coordX, coordY, col + 3, rowInner);
Expand All @@ -45,12 +46,12 @@ const conv2dTransposeCommonSnippet = (innerElementSize = 4): string => {
let v3 = W[getIndexFromCoords4D(coord3, wShape)];
return vec4<f32>(v0, v1, v2, v3);
`;
default:
throw new Error(`innerElementSize ${innerElementSize} is not supported.`);
}
};
default:
throw new Error(`innerElementSize ${innerElementSize} is not supported.`);
}
};

const readASnippet = `
const readASnippet = `
let outRow = row / outShape[2];
let outCol = row % outShape[2];
Expand All @@ -71,12 +72,13 @@ const conv2dTransposeCommonSnippet = (innerElementSize = 4): string => {
col % outBackprop[3]);
return x[getIndexFromCoords4D(coord, xShape)/${innerElementSize}];`;

const sampleA = `if (row < dimAOuter && col < dimInner) {
const sampleA = `if (row < dimAOuter && col < dimInner) {
${readASnippet}
}
return ${typeSnippet(innerElementSize)}(0.0);`;

const userCode = `
const userCode = `
${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)}
fn mm_readA(batch: i32, row : i32, col : i32) -> ${typeSnippet(innerElementSize)} {
${sampleA}
}
Expand All @@ -98,16 +100,17 @@ const conv2dTransposeCommonSnippet = (innerElementSize = 4): string => {
fn mm_write(batch: i32, row : i32, col : i32, valueInput : ${typeSnippet(innerElementSize)}) {
if (row < dimAOuter && col < dimBOuter) {
var value = valueInput;
let outCoord = vec4<i32>(
let coords = vec4<i32>(
batch,
row / outShape[2],
row % outShape[2],
col);
result[getIndexFromCoords4D(outCoord, outShape)/${innerElementSize}] = value;
${biasActivationSnippet(addBias, activation)}
result[getIndexFromCoords4D(coords, outShape)/${innerElementSize}] = value;
}
}`;
return userCode;
};
return userCode;
};

export const createConv2DTransposeMatMulProgramInfo =
(inputs: readonly TensorView[], metadata: ProgramMetadata, attributes: ConvTransposeAttributes,
Expand All @@ -125,46 +128,59 @@ export const createConv2DTransposeMatMulProgramInfo =

const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight;
const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels;
const workGroupSize: [number, number, number] = isVec4 ? [8, 8, 1] : [4, 4, 1];
const workGroupSize: [number, number, number] =
isVec4 ? [8, 8, 1] : [dispatchX <= 4 ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1];
const elementsPerThread =
isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 2, dispatchX > 4 && dispatchY <= 4 ? 1 : 2, 1];
const dispatch = [
Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]),
Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]),
Math.ceil(batchSize / workGroupSize[2] / elementsPerThread[2])
];
const innerElementSize = isVec4 ? 4 : 1;

LOG_DEBUG('verbose', () => `[conv_backprop_mm_webgpu] dispatch = ${dispatch}`);

const innerElementSize = isVec4 ? (inChannels % 4 !== 0 ? 3 : 4) : 1;
const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]);

LOG_DEBUG('verbose', () => `[conv2d_mm_webgpu] dispatch = ${dispatch}`);

const declareInputs = [
`@group(0) @binding(0) var<storage, read> x: array<${isVec4 ? 'vec4<f32>' : 'f32'}>;`,
`@group(0) @binding(1) var<storage, read> W: array<${isVec4 ? 'vec4<f32>' : 'f32'}>;`
'@group(0) @binding(1) var<storage, read> W: array<f32>;'
];

let declareFunctions = '';
if (hasBias) {
declareInputs.push(`@group(0) @binding(2) var<storage, read> bias: array<${isVec4 ? 'vec4<f32>' : 'f32'}>;`);
declareFunctions += `
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? 'vec4<f32>' : 'f32'} {
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
}`;
}
return {
...metadata,
outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}],
dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}),
getShaderSource: () => `
${utilFunctions}
${declareInputs.join('')}
${declareInputs.join('\n')}
@group(0) @binding(${declareInputs.length}) var<storage, read_write> result: array<${
isVec4 ? 'vec4<f32>' : 'f32'}>;
const outBackprop : vec4<i32> = vec4<i32>(${inputs[0].dims.join(',')});
const xShape : vec4<i32> = vec4<i32>(${inputs[0].dims.join(',')});
const wShape : vec4<i32> = vec4<i32>(${inputs[1].dims.join(',')});
const outShape : vec4<i32> = vec4<i32>(${outputShape.join(',')});
const outShapeStrides : vec3<i32> = vec3<i32>(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')});
const filterDims : vec2<i32> = vec2<i32>(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]});
const pads : vec2<i32> = vec2<i32>(${attributes.pads[0]}, ${attributes.pads[1]});
const filterDims : vec2<i32> = vec2<i32>(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${
attributes.kernelShape[isChannelsLast ? 2 : 3]});
const pads : vec2<i32> = vec2<i32>(i32(filterDims[0]) - 1 - (${attributes.pads[0] + attributes.pads[2]})/2,
i32(filterDims[1]) - 1 - (${attributes.pads[1] + attributes.pads[3]})/2);
const strides : vec2<i32> = vec2<i32>(${attributes.strides[0]}, ${attributes.strides[1]});
const dilation : vec2<i32> = vec2<i32>(${attributes.dilations[0]}, ${attributes.dilations[1]});
const dimAOuter : i32 = ${dimAOuter};
const dimBOuter : i32 = ${dimBOuter};
const dimInner : i32 = ${dimInner};
${conv2dTransposeCommonSnippet(innerElementSize)}
${declareFunctions}
${conv2dTransposeCommonSnippet(hasBias, undefined, false, innerElementSize)}
${
isVec4 ?
makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner) :
Expand Down
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts
Original file line number | Diff line number | Diff line change
Expand Up @@ -65,7 +65,7 @@ const getAdjustedConvTransposeAttributes =
<T extends ConvTransposeAttributes>(attributes: T, inputs: readonly TensorView[]): T => {
const kernelShape = attributes.kernelShape.slice();
// if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims
if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 0) === 0) {
if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 1) === 0) {
kernelShape.length = 0;
for (let i = 2; i < inputs[1].dims.length; ++i) {
kernelShape.push(inputs[1].dims[i]);
Expand Down Expand Up @@ -236,7 +236,7 @@ const convTranspose2d =
const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs);
const isChannelsLast = attributes.format === 'NHWC';
const hasBias = inputs.length === 3;
if (adjustedAttributes.group !== 1 || hasBias) {
if (adjustedAttributes.group !== 1 || !isChannelsLast) {
context.compute(createConvTranspose2DProgramInfoLoader(inputs, adjustedAttributes));
return;
}
Expand Down
2 changes: 2 additions & 0 deletions js/web/test/data/ops/conv-transpose.jsonc
Original file line number | Diff line number | Diff line change
Expand Up @@ -289,6 +289,8 @@
{
"name": "ConvTranspose with bias addition C",
"operator": "ConvTranspose",
"inputShapeDefinitions": "rankOnly",
"opset": { "domain": "", "version": 17 },
"attributes": [{ "name": "kernel_shape", "data": [1, 1], "type": "ints" }],
"cases": [
{
Expand Down

0 comments on commit 6794407

Please sign in to comment.