diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts index c4c4849dfc494..d8804eebcec55 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts @@ -333,13 +333,13 @@ export const createMatMulNBitsBlockwiseProgramInfo = ( ): ProgramInfo => { const inputShape = inputs[0].dims; const aRank = inputShape.length; - const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); + const nBlocksPerCol = inputs[1].dims[1]; const dimAOuter = inputShape[aRank - 2]; const dimInner = attributes.k; const dimBOuter = attributes.n; const batchDims = inputShape.slice(0, aRank - 2); const batchSize = ShapeUtil.size(batchDims); - const blobSize = (attributes.blockSize / 8) * attributes.bits; + const blobSize = inputs[1].dims[2]; const blobSizeInWords = blobSize / 4; const dataType = inputs[0].dataType; const aComponents = getMaxComponents(attributes.k); @@ -498,10 +498,11 @@ export const createMatMulNBitsBlockwiseProgramInfo = ( export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => { validateInputs(context.inputs, attributes); - if (context.inputs.length < 4) { + const nBlocksPerCol = context.inputs[1].dims[1]; + const maxComputeWorkgroupSizes: [number, number, number] = context.getMaxComputeWorkgroupSizes(); + if (context.inputs.length < 4 && nBlocksPerCol < maxComputeWorkgroupSizes[0]) { context.compute(createMatMulNBitsBlockwiseProgramInfo(context.inputs, attributes)); } else { - const maxComputeWorkgroupSizes: [number, number, number] = context.getMaxComputeWorkgroupSizes(); const maxComputeWorkgroupStorageSize = context.getMaxComputeWorkgroupStoragesize(); context.compute( createMatMulNBitsProgramInfo(