Skip to content

Commit

Permalink
Don't use workgroup memory for B
Browse files (browse the repository at this point in the history)
  • Loading branch information
qjia7 committed Oct 12, 2024
1 parent 429961e commit 6f9845d
Showing 1 changed file with 4 additions and 11 deletions.
15 changes: 4 additions & 11 deletions js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts
Original file line number | Diff line number | Diff line change
Expand Up @@ -289,7 +289,7 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
const outputShape = batchDims.concat([dimAOuter, dimBOuter]);

const workgroupSize = 64;
const workgroupY = 8;
const workgroupY = 4;
const workgroupX = workgroupSize / workgroupY;
const tileSize = workgroupX * bComponents * 8; // each uint32 has 8 data.
const aLengthPerTile = tileSize / aComponents;
Expand Down Expand Up @@ -363,8 +363,10 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
let calcStr = `var col_index = col * ${components};`;
for (let c = 0; c < components; c++) {
calcStr += `
let b_row = workgroup_id.x * ${workgroupY} + local_id.y;
let block = tile * ${blocksPerTile} + local_id.x;
let scale${c} = ${scales.getByOffset(`b_row * n_blocks_per_col + block`)};
let b${c}_data = sub_b[local_id.y][local_id.x];
let b${c}_data = ${b.getByIndices(`${b.type.indices}(b_row, block, 0)`)};
col_index += 1;`;
}
calcStr += `
Expand All @@ -378,7 +380,6 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
};
return `
var<workgroup> sub_a: array<${a.type.value}, ${aLengthPerTile}>;
var<workgroup> sub_b: array<array<${b.type.value}, ${workgroupX}>, ${workgroupY}>;
var<workgroup> inter_results: array<array<${output.type.value}, ${workgroupX}>, ${workgroupY}>;
${shaderHelper.declareVariables(...inputVariables, output)}
${shaderHelper.mainStart([workgroupX, workgroupY, 1])}
Expand All @@ -405,14 +406,6 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
sub_a[a_offset] = ${a.type.value}(0);
}
}
// load one tile B data into shared memory.
let b_row = workgroup_id.x * ${workgroupY} + local_id.y;
let block = tile * ${blocksPerTile} + local_id.x;
if (b_row < uniforms.b_shape[0] && block < uniforms.b_shape[1]) {
sub_b[local_id.y][local_id.x] = ${b.getByIndices(`${b.type.indices}(b_row, block, 0)`)};
} else {
sub_b[local_id.y][local_id.x] = ${b.type.value}(0);
}
workgroupBarrier();
// each thread process one block
Expand Down

0 comments on commit 6f9845d

Please sign in to comment.