Skip to content

Commit

Permalink
Changes to make any combination of components work.
Browse files Browse the repository at this point in the history
  • Loading branch information
satyajandhyala committed Mar 25, 2024
1 parent 122753f commit c18c465
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,15 @@ export const createMatMulNBitsProgramInfo =
Array.from({length: 8}, (_, i) => `${dataType}((value >> ${(i * 4).toString()}) & 0xFu)`).join(', ')});
}`;

// Shader-source snippet that steps to the next 4-bit zero point; it is an empty string when
// `zeroPoints` is falsy. The emitted code advances `zero_point_offset` by 4 bits and, once the
// offset reaches 32, resets it to 0, increments `zero_point_index` and reloads `zero_point_word`
// through `zeroPoints.getByOffset('zero_point_index')`.
const updateZeroPointIndex = zeroPoints ? `
zero_point_offset += 4;
if (zero_point_offset == 32) {
zero_point_offset = 0;
zero_point_index++;
zero_point_word = ${zeroPoints.getByOffset('zero_point_index')};
}` :
'';

return `
${dequantizeImpl};
${ortUnpack8x4snormImpl};
Expand All @@ -143,14 +152,13 @@ export const createMatMulNBitsProgramInfo =
var row = ${outputNumber} * m;
var col = ${components} * n;
var a_indices: ${a.type.indices} = output_indices;
// Two zero points are packed into one byte because uniforms.bits <= 4.
// zero_point_offset is nibble offset within one word.
// TODO support zero_point_offset for bits > 4
// Two zero points are packed into one byte when uniforms.bits is 4.
${
zeroPoints ? `
var zero_point_index: u32 = col * ((${nBlocksPerCol} + 1) / 2) / 4;
var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')};
var zero_point_offset: u32 = 0;` :
var zero_point_byte_offset: u32 = (col * ((${nBlocksPerCol} + 1) / 2)) % 4;
var zero_point_offset: u32 = 8 * zero_point_byte_offset;` :
''}
var scale_index = col * ${nBlocksPerCol};
var b_indices: ${b.type.indices};
Expand All @@ -161,7 +169,7 @@ export const createMatMulNBitsProgramInfo =
// The scale and zero points are computed per block.
let scale = ${scales.getByOffset('scale_index')};
// The default zero point is 8 for unsigned 4-bit quantization.
let zero_point = ${dataType}(${zeroPoints ? 'zero_point_word & 0xFu' : 8.0});
let zero_point = ${dataType}(${zeroPoints ? 'extractBits(zero_point_word, zero_point_offset, 4)' : 8.0});
${b.indicesSet('b_indices', '1', 'block')};
var word_offset: u32 = block_offset;
for (var word: u32 = 0; word < ${blobSizeInWords}; word += ${bComponents}) {
Expand All @@ -187,18 +195,15 @@ export const createMatMulNBitsProgramInfo =
}
}
scale_index++;
${
zeroPoints ? `if ((scale_index + ${nBlocksPerCol % 2 ? 'c' : '0'}) % 8 == 0) {
zero_point_index++;
zero_point_word = ${zeroPoints.getByOffset('zero_point_index')};
} else {
zero_point_word >>= 4;
}` :
''}
${updateZeroPointIndex}
block_offset += uniforms.block_size / ${aComponents};
}
// Drop the trailing 4 bits if zero_point_offset is not on a byte boundary, so that it aligns with the next byte.
${zeroPoints && (nBlocksPerCol % 2 !== 0) ? 'zero_point_word >>= 4;' : ''}
${
zeroPoints ? `if (zero_point_offset % 8 > 0) {
${updateZeroPointIndex}
}` :
''}
}
for (var k: u32 = 0u; k < ${outputNumber}u; k++) {
${output.indicesSet('output_indices', aRank - 2, `${outputNumber + ' * m + k'}`)};
Expand Down

0 comments on commit c18c465

Please sign in to comment.