Skip to content

Commit

Permalink
[js/webgpu] Optimize MatMul with M = 1 (microsoft#22577)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
BUG microsoft#22031

In the demucs model, there are lots of MatMul ops with shapes like
below:
`input[0]: [3448,1,512] | float32, input[1]: [512,1536] | float32,
output[0]: [3448,1,1536] | float32`

We can see that for this kind of shape, the batch size is a big value,
but M = 1. Our current algorithm is based on [M, N] to partition tiles,
which is not efficient for such kind of shapes. This PR reshapes the
inputs to improve the matmul performance.
Before:  [3448,1,512] x [512,1536] =  [3448,1,1536]
After: [1, 3448, 512] x [512, 1536] = [1, 3448, 1536] , then the output
can be reshaped to [3448, 1, 1536]

The overall MatMul time in demucs model becomes 1778.45 ms from 4418.17
ms on my iGPUs.

---------

Co-authored-by: Yulong Wang <[email protected]>
  • Loading branch information
2 people authored and ankitm3k committed Dec 11, 2024
1 parent d194236 commit 2827305
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 1 deletion.
15 changes: 14 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/matmul.ts
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,19 @@ export const matMul = (context: ComputeContext): void => {
if (N < 8 && K < 8) {
context.compute(createNaiveMatmulProgramInfo(context.inputs, { activation: '' }, outputShape));
} else {
context.compute(createMatmulProgramInfo(context.inputs, { activation: '' }, outputShape));
const M = outputShape[outputShape.length - 2];
const batchA = ShapeUtil.size(context.inputs[0].dims.slice(0, -2));
const batchB = ShapeUtil.size(context.inputs[1].dims.slice(0, -2));
if (batchA !== 1 && M === 1 && batchB === 1) {
const reshapedA = context.inputs[0].reshape([1, batchA, K]);
const reshapedB = context.inputs[1].reshape([1, K, N]);
const matmulOutputShape = [1, batchA, N];
const matmulInputs = [reshapedA, reshapedB];
context.compute(createMatmulProgramInfo(matmulInputs, { activation: '' }, outputShape, matmulOutputShape), {
inputs: matmulInputs,
});
} else {
context.compute(createMatmulProgramInfo(context.inputs, { activation: '' }, outputShape));
}
}
};
50 changes: 50 additions & 0 deletions js/web/test/data/ops/matmul.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,56 @@
}
]
},
{
"name": "multiplies 3D tensors with M = 1",
"inputs": [
{
"data": [
1, 2, 3, 4, 5, 6, 7, 8, 17, 18, 19, 20, 21, 22, 23, 24, 9, 10, 11, 12, 13, 14, 15, 16, 2, 3, 4, 5, 6, 7,
8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 1, 2, 3, 4, 5, 6, 7, 8
],
"dims": [6, 1, 8],
"type": "float32"
},
{
"data": [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17],
"dims": [1, 8, 3],
"type": "float32"
}
],
"outputs": [
{
"data": [478, 514, 550, 2270, 2434, 2598, 1374, 1474, 1574, 590, 634, 678, 1486, 1594, 1702, 478, 514, 550],
"dims": [6, 1, 3],
"type": "float32"
}
]
},
{
"name": "multiplies 4D tensors with M = 1",
"inputs": [
{
"data": [
1, 2, 3, 4, 5, 6, 7, 8, 17, 18, 19, 20, 21, 22, 23, 24, 9, 10, 11, 12, 13, 14, 15, 16, 2, 3, 4, 5, 6, 7,
8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 1, 2, 3, 4, 5, 6, 7, 8
],
"dims": [2, 3, 1, 8],
"type": "float32"
},
{
"data": [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17],
"dims": [1, 1, 8, 3],
"type": "float32"
}
],
"outputs": [
{
"data": [478, 514, 550, 2270, 2434, 2598, 1374, 1474, 1574, 590, 634, 678, 1486, 1594, 1702, 478, 514, 550],
"dims": [2, 3, 1, 3],
"type": "float32"
}
]
},
{
"name": "multiplies 4D tensors",
"inputs": [
Expand Down

0 comments on commit 2827305

Please sign in to comment.