Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
qjia7 committed Oct 24, 2024
1 parent 7a9235f commit 67f5e35
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
3 changes: 2 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/matmul.ts
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,9 @@ export const matMul = (context: ComputeContext): void => {
const batchB = ShapeUtil.size(context.inputs[1].dims.slice(0, -2));
if (batchA !== 1 && M === 1 && batchB === 1) {
const reshapedA = context.inputs[0].reshape([1, batchA, K]);
const reshapedB = context.inputs[0].reshape([1, K, N]);
const matmulOutputShape = [1, batchA, N];
const matmulInputs = [reshapedA, context.inputs[1]]
const matmulInputs = [reshapedA, reshapedB]
context.compute(createMatmulProgramInfo(matmulInputs, { activation: '' }, outputShape, matmulOutputShape), { inputs: matmulInputs });
} else {
context.compute(createMatmulProgramInfo(context.inputs, { activation: '' }, outputShape));
Expand Down
48 changes: 48 additions & 0 deletions js/web/test/data/ops/matmul.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,54 @@
}
]
},
{
"name": "multiplies 3D tensors with M = 1",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 17, 18, 19, 20, 21, 22, 23, 24, 9, 10, 11, 12, 13, 14, 15, 16],
"dims": [6, 1, 4],
"type": "float32"
},
{
"data": [
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24
],
"dims": [1, 4, 3],
"type": "float32"
}
],
"outputs": [
{
"data": [190,200,210,470,496,522,1310,1384,1458,1590,1680,1770,750,792,834,1030,1088,1146],
"dims": [6, 1, 3],
"type": "float32"
}
]
},
{
"name": "multiplies 4D tensors with M = 1",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 17, 18, 19, 20, 21, 22, 23, 24, 9, 10, 11, 12, 13, 14, 15, 16],
"dims": [2, 3, 1, 4],
"type": "float32"
},
{
"data": [
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24
],
"dims": [1, 1, 4, 3],
"type": "float32"
}
],
"outputs": [
{
"data": [190,200,210,470,496,522,1310,1384,1458,1590,1680,1770,750,792,834,1030,1088,1146],
"dims": [2, 3, 1, 3],
"type": "float32"
}
]
},
{
"name": "multiplies 4D tensors",
"inputs": [
Expand Down

0 comments on commit 67f5e35

Please sign in to comment.