switch conv2d to the new matmul
qjia7 committed Oct 18, 2023
1 parent a526d1b commit a3b1e49
Showing 2 changed files with 11 additions and 10 deletions.
js/web/lib/wasm/jsep/webgpu/ops/conv.ts (4 changes: 2 additions & 2 deletions)

@@ -7,9 +7,9 @@ import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-w
 import {ComputeContext} from '../types';
 
 import {createConv2DMatMulProgramInfo} from './3rd-party/conv2d_mm_webgpu';
-import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu';
 import {createGroupedConvProgramInfo} from './conv-grouped';
 import {InternalActivationAttributes, parseInternalActivationAttributes} from './fuse-utils';
+import {createMatmulLWGProgramInfo} from './matmul';
 import {createTransposeProgramInfo} from './transpose';
 
 export const calculateOutputShape =
@@ -196,7 +196,7 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
       matmulInputs.push(inputs[2]);
     }
     context.compute(
-        createMatmulProgramInfo(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast),
+        createMatmulLWGProgramInfo(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast),
         {inputs: matmulInputs});
     return;
   }
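Note: the switch above replaces the 3rd-party packed matmul with the in-tree
createMatmulLWGProgramInfo on the conv2d-by-matmul path. As a rough sketch of
why that path reduces to a plain matmul (this helper and its shapes are
illustrative, not part of the commit): a 1x1, stride-1 conv in channels-last
layout maps to a batched matmul like so:

    // Hypothetical helper, shapes only, assuming NHWC layout and a
    // 1x1 kernel with stride 1 and no padding.
    const conv1x1AsMatmulShapes =
        (inputShape: readonly number[],    // [N, H, W, Cin]
         weightShape: readonly number[]):  // [1, 1, Cin, Cout]
            {aShape: number[], bShape: number[], outShape: number[]} => {
          const [n, h, w, cin] = inputShape;
          const cout = weightShape[3];
          return {
            aShape: [n, h * w, cin],     // activations viewed as a batch of [M, K]
            bShape: [1, cin, cout],      // weights viewed as a single [K, N]
            outShape: [n, h * w, cout],  // reshaped back to [N, H, W, Cout] after
          };
        };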
js/web/lib/wasm/jsep/webgpu/ops/matmul.ts (17 changes: 9 additions & 8 deletions)

@@ -17,8 +17,8 @@ const getMaxComponents = (size: number): 1|2|3|4 => {
   return 1;
 };
 export const createMatmulLWGProgramInfo =
-    (inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes,
-     outputShape: readonly number[], reshapedOutputShape?: readonly number[],
+    (inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, outputShape: readonly number[],
+     reshapedOutputShape?: readonly number[],
      isChannelsLast = false /* only used for conv2dByMatMul*/): ProgramInfo => {
       const aShape = inputs[0].dims;
       const bShape = inputs[1].dims;
@@ -38,6 +38,10 @@ export const createMatmulLWGProgramInfo =
       const workgroupSize = 64;
       // The output number of each thread.
       const outputNumber = Math.ceil(tileM / Math.floor(workgroupSize / (tileN / components)));
+      if (workgroupSize < (tileN / components)) {
+        throw new Error(
+            `workgroupSize ${workgroupSize} must be larger than or equal to tileN / components ${tileN / components}`);
+      }
       // The virtualXXX makes sure that one tile of data has the same batch.
       const virtualM = Math.ceil(M / tileM) * tileM;
       const virtualN = Math.ceil(N / tileN) * tileN;
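Note: the new guard rejects tile configurations that a single 64-thread
workgroup cannot cover. A worked example with assumed tile sizes (tileM,
tileN, and components are computed earlier in the function, outside this
hunk):

    // Hypothetical values: tileM = 32, tileN = 32, components = 4.
    // Each thread writes `components`-wide vectors, so one tile row needs
    // tileN / components = 8 threads; floor(64 / 8) = 8 rows per pass, and
    // each thread produces ceil(32 / 8) = 4 outputs.
    const workgroupSize = 64;
    const tileM = 32, tileN = 32, components = 4;
    const threadsPerRow = tileN / components;  // 8 <= 64, so the guard passes
    const outputNumber = Math.ceil(tileM / Math.floor(workgroupSize / threadsPerRow));  // 4
    // With, say, tileN = 512 and components = 1, threadsPerRow would be
    // 512 > 64 and Math.floor(64 / 512) would be 0, so the guard throws
    // instead of letting the division above produce Infinity.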
@@ -52,8 +56,7 @@ export const createMatmulLWGProgramInfo =
         const biasComponents = isChannelsLast ? components : 1;
         inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims, biasComponents));
         processBias = `${
-            isChannelsLast ? `value += bias[col / ${biasComponents}];` :
-                             `value += ${output.type.value}(bias[row]);`}`;
+            isChannelsLast ? `value += bias[col / ${biasComponents}];` : `value += ${output.type.value}(bias[row]);`}`;
       }
       const outerDimsA = aShape.slice(0, -2);
       const outerDimsB = bShape.slice(0, -2);
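Note: the processBias change is a pure reflow; behavior is unchanged. For
reference, the template expands to one of two WGSL fragments (the concrete
values below are assumed for illustration):

    // Sketch of the two expansions, assuming biasComponents = 4 and a
    // vec4<f32> output type (standing in for output.type.value).
    const biasComponents = 4;
    const outputTypeValue = 'vec4<f32>';
    const channelsLast = `value += bias[col / ${biasComponents}];`;  // bias indexed per output column
    const rowMajor = `value += ${outputTypeValue}(bias[row]);`;      // scalar bias per row, splatted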
@@ -194,10 +197,8 @@ export const createMatmulLWGProgramInfo =
       return {
         name: 'MatMulLinearWG',
         shaderCache: {hint: activationAttributes.activationCacheKey},
-        getRunData: () => ({
-          outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
-          dispatchGroup: {x: numWorkgroups}
-        }),
+        getRunData: () =>
+            ({outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: numWorkgroups}}),
         getShaderSource,
       };
     };
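Note: the getRunData change is also a pure reflow. For context, a minimal
sketch of how such a ProgramInfo is consumed; the activation attributes and
output shape here are placeholders, not values from this commit:

    // Hypothetical call site, mirroring the conv.ts change above.
    const runMatmul = (context: ComputeContext, inputs: readonly TensorView[]) => {
      const programInfo = createMatmulLWGProgramInfo(
          inputs, {activation: '', activationCacheKey: ''},  // assumed: no fused activation
          /* outputShape */[2, 64, 64]);
      // The backend reads getRunData() for output dims/types and the 1-D
      // dispatch size ({x: numWorkgroups}), then dispatches the WGSL
      // returned by getShaderSource over that many workgroups.
      context.compute(programInfo, {inputs});
    };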
