Skip to content

Commit

Permalink
WebGPU JSEP: Add inputs broadcasting into MatMul shader cache key
Browse files Browse the repository at this point in the history
This PR adds input broadcasting information into the cache key of
MatMul shaders, since broadcasting currently affects the generated
shader code. This fixes incorrect results for MatMul nodes that have
identical input ranks but different broadcasting patterns.
  • Loading branch information
jiangzhaoming committed Oct 22, 2024
1 parent 60da4a2 commit 1ee2d5e
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -548,10 +548,12 @@ export const createMatmulProgramInfo = (
}
`;
};
// Input broadcasting impacts the shader code, and should be handled in cache key
const inputBroadcastingDims = `${getBroadcastDims(outerDimsA, outerDims)};${getBroadcastDims(outerDimsB, outerDims)}`;
return {
name: 'MatMul',
shaderCache: {
hint: `${elementsPerThread};${activationAttributes.activation};${isVec4};${isChannelsLast}`,
hint: `${elementsPerThread};${activationAttributes.activation};${isVec4};${inputBroadcastingDims};${isChannelsLast}`,
inputDependencies,
},
getRunData: () => ({
Expand Down
17 changes: 10 additions & 7 deletions js/web/lib/wasm/jsep/webgpu/ops/matmul.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ export const createNaiveMatmulProgramInfo = (
}
programUniforms.push(...createTensorShapeVariables(outputShapeInShader));

const outerDimsA = aShape.slice(0, -2);
const outerDimsB = bShape.slice(0, -2);
const broadcastADims = getBroadcastDims(outerDimsA, outerDims);
const broadcastBDims = getBroadcastDims(outerDimsB, outerDims);

const getShaderSource = (shaderHelper: ShaderHelper) => {
const batchDims = internalVariable('batch_dims', inputs[0].dataType, outerDims.length);
const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents);
Expand All @@ -79,10 +84,6 @@ export const createNaiveMatmulProgramInfo = (
}`;
}

const outerDimsA = aShape.slice(0, -2);
const outerDimsB = bShape.slice(0, -2);
const broadCastADims = getBroadcastDims(outerDimsA, outerDims);
const broadCastBDims = getBroadcastDims(outerDimsB, outerDims);
const uniforms: UniformsArrayType = [
{ name: 'output_size', type: 'u32' },
{ name: 'M', type: 'u32' },
Expand Down Expand Up @@ -141,9 +142,9 @@ export const createNaiveMatmulProgramInfo = (
let batch = index1 / stride1;
${outputShape.length === 2 ? '' : `let batch_indices = ${batchDims.offsetToIndices('batch')};`}
${getIndices(a, broadCastADims)}
${getIndices(a, broadcastADims)}
let a_offset = ${a.indicesToOffset('a_indices')};
${getIndices(b, broadCastBDims)}
${getIndices(b, broadcastBDims)}
let b_offset = ${b.indicesToOffset('b_indices')};
var values: array<${output.type.value}, ${outputNumber}>;
for (var k: u32 = 0u; k < uniforms.K; k = k + ${aComponents}) {
Expand All @@ -160,10 +161,12 @@ export const createNaiveMatmulProgramInfo = (
}
`;
};
// Input broadcasting impacts the shader code, and should be handled in cache key
const inputBroadcastingDims = `${broadcastADims};${broadcastBDims}`;
return {
name: 'MatMulNaive',
shaderCache: {
hint: `${activationAttributes.activation};${components};${aComponents};${outputNumber};${isChannelsLast}`,
hint: `${activationAttributes.activation};${components};${aComponents};${outputNumber};${inputBroadcastingDims};${isChannelsLast}`,
inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank'],
},
getRunData: () => ({
Expand Down
88 changes: 88 additions & 0 deletions js/web/test/data/ops/matmul.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,94 @@
"type": "float32"
}
]
},
{
"name": "same ranks different broadcast small 0",
"inputs": [
{
"data": [0, 1, 2, 3, 4, 5, 6, 7],
"dims": [1, 2, 2, 2],
"type": "float32"
},
{
"data": [8, 9, 10, 11],
"dims": [2, 1, 2, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [9, 43, 77, 111, 11, 53, 95, 137],
"dims": [2, 2, 2, 1],
"type": "float32"
}
]
},
{
"name": "same ranks different broadcast small 1",
"inputs": [
{
"data": [0, 1, 2, 3, 4, 5, 6, 7],
"dims": [2, 1, 2, 2],
"type": "float32"
},
{
"data": [8, 9, 10, 11],
"dims": [1, 2, 2, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [9, 43, 11, 53, 77, 111, 95, 137],
"dims": [2, 2, 2, 1],
"type": "float32"
}
]
},
{
"name": "same ranks different broadcast larger 0",
"inputs": [
{
"data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
"dims": [1, 2, 2, 8],
"type": "float32"
},
{
"data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
"dims": [2, 1, 8, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [1036, 3308, 5580, 7852, 1260, 4044, 6828, 9612],
"dims": [2, 2, 2, 1],
"type": "float32"
}
]
},
{
"name": "same ranks different broadcast larger 1",
"inputs": [
{
"data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
"dims": [2, 1, 2, 8],
"type": "float32"
},
{
"data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
"dims": [1, 2, 8, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [1036, 3308, 1260, 4044, 5580, 7852, 6828, 9612],
"dims": [2, 2, 2, 1],
"type": "float32"
}
]
}
]
}
Expand Down

0 comments on commit 1ee2d5e

Please sign in to comment.