Skip to content

Commit

Permalink
add limitations
Browse files Browse the repository at this point in the history
  • Loading branch information
qjia7 committed Aug 19, 2024
1 parent e007210 commit 2bf70ef
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts
Original file line number Diff line number Diff line change
Expand Up @@ -333,13 +333,13 @@ export const createMatMulNBitsBlockwiseProgramInfo = (
): ProgramInfo => {
const inputShape = inputs[0].dims;
const aRank = inputShape.length;
const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize);
const nBlocksPerCol = inputs[1].dims[1];
const dimAOuter = inputShape[aRank - 2];
const dimInner = attributes.k;
const dimBOuter = attributes.n;
const batchDims = inputShape.slice(0, aRank - 2);
const batchSize = ShapeUtil.size(batchDims);
const blobSize = (attributes.blockSize / 8) * attributes.bits;
const blobSize = inputs[1].dims[2];
const blobSizeInWords = blobSize / 4;
const dataType = inputs[0].dataType;
const aComponents = getMaxComponents(attributes.k);
Expand Down Expand Up @@ -498,10 +498,11 @@ export const createMatMulNBitsBlockwiseProgramInfo = (

export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => {
validateInputs(context.inputs, attributes);
if (context.inputs.length < 4) {
const nBlocksPerCol = context.inputs[1].dims[1];
const maxComputeWorkgroupSizes: [number, number, number] = context.getMaxComputeWorkgroupSizes();
if (context.inputs.length < 4 && nBlocksPerCol < maxComputeWorkgroupSizes[0]) {
context.compute(createMatMulNBitsBlockwiseProgramInfo(context.inputs, attributes));
} else {
const maxComputeWorkgroupSizes: [number, number, number] = context.getMaxComputeWorkgroupSizes();
const maxComputeWorkgroupStorageSize = context.getMaxComputeWorkgroupStoragesize();
context.compute(
createMatMulNBitsProgramInfo(
Expand Down

0 comments on commit 2bf70ef

Please sign in to comment.