diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index 4062709a0114b..f49d94e374d76 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -206,8 +206,9 @@ export const matMul = (context: ComputeContext): void => { const batchB = ShapeUtil.size(context.inputs[1].dims.slice(0, -2)); if (batchA !== 1 && M === 1 && batchB === 1) { const reshapedA = context.inputs[0].reshape([1, batchA, K]); + const reshapedB = context.inputs[0].reshape([1, K, N]); const matmulOutputShape = [1, batchA, N]; - const matmulInputs = [reshapedA, context.inputs[1]] + const matmulInputs = [reshapedA, reshapedB] context.compute(createMatmulProgramInfo(matmulInputs, { activation: '' }, outputShape, matmulOutputShape), { inputs: matmulInputs }); } else { context.compute(createMatmulProgramInfo(context.inputs, { activation: '' }, outputShape)); diff --git a/js/web/test/data/ops/matmul.jsonc b/js/web/test/data/ops/matmul.jsonc index 2c2cf509d7e3e..e1aa324a8e0a6 100644 --- a/js/web/test/data/ops/matmul.jsonc +++ b/js/web/test/data/ops/matmul.jsonc @@ -95,6 +95,54 @@ } ] }, + { + "name": "multiplies 3D tensors with M = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 17, 18, 19, 20, 21, 22, 23, 24, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [6, 1, 4], + "type": "float32" + }, + { + "data": [ + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 + ], + "dims": [1, 4, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [190,200,210,470,496,522,1310,1384,1458,1590,1680,1770,750,792,834,1030,1088,1146], + "dims": [6, 1, 3], + "type": "float32" + } + ] + }, + { + "name": "multiplies 4D tensors with M = 1", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 17, 18, 19, 20, 21, 22, 23, 24, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [2, 3, 1, 4], + "type": "float32" + }, + { + "data": [ + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 + ], + "dims": [1, 1, 4, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [190,200,210,470,496,522,1310,1384,1458,1590,1680,1770,750,792,834,1030,1088,1146], + "dims": [2, 3, 1, 3], + "type": "float32" + } + ] + }, { "name": "multiplies 4D tensors", "inputs": [