From 2bf70efe5f1d5b0905a680886c2169bdcc5eeab9 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Mon, 19 Aug 2024 15:07:51 +0800 Subject: [PATCH] add limitations --- js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts index c4c4849dfc494..d8804eebcec55 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts @@ -333,13 +333,13 @@ export const createMatMulNBitsBlockwiseProgramInfo = ( ): ProgramInfo => { const inputShape = inputs[0].dims; const aRank = inputShape.length; - const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); + const nBlocksPerCol = inputs[1].dims[1]; const dimAOuter = inputShape[aRank - 2]; const dimInner = attributes.k; const dimBOuter = attributes.n; const batchDims = inputShape.slice(0, aRank - 2); const batchSize = ShapeUtil.size(batchDims); - const blobSize = (attributes.blockSize / 8) * attributes.bits; + const blobSize = inputs[1].dims[2]; const blobSizeInWords = blobSize / 4; const dataType = inputs[0].dataType; const aComponents = getMaxComponents(attributes.k); @@ -498,10 +498,11 @@ export const createMatMulNBitsBlockwiseProgramInfo = ( export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => { validateInputs(context.inputs, attributes); - if (context.inputs.length < 4) { + const nBlocksPerCol = context.inputs[1].dims[1]; + const maxComputeWorkgroupSizes: [number, number, number] = context.getMaxComputeWorkgroupSizes(); + if (context.inputs.length < 4 && nBlocksPerCol < maxComputeWorkgroupSizes[0]) { context.compute(createMatMulNBitsBlockwiseProgramInfo(context.inputs, attributes)); } else { - const maxComputeWorkgroupSizes: [number, number, number] = context.getMaxComputeWorkgroupSizes(); const maxComputeWorkgroupStorageSize = context.getMaxComputeWorkgroupStoragesize(); context.compute( createMatMulNBitsProgramInfo(