From c18c465c0ae2e948736e720f2db859acac310739 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Mon, 25 Mar 2024 13:45:28 -0700 Subject: [PATCH] Changes to make any combination of components work. --- .../lib/wasm/jsep/webgpu/ops/matmulnbits.ts | 33 +++++++++++-------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts index 8e630c7e2e19a..d0dd5f94755fa 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts @@ -130,6 +130,15 @@ export const createMatMulNBitsProgramInfo = Array.from({length: 8}, (_, i) => `${dataType}((value >> ${(i * 4).toString()}) & 0xFu)`).join(', ')}); }`; + const updateZeroPointIndex = zeroPoints ? ` + zero_point_offset += 4; + if (zero_point_offset == 32) { + zero_point_offset = 0; + zero_point_index++; + zero_point_word = ${zeroPoints.getByOffset('zero_point_index')}; + }` : + ''; + return ` ${dequantizeImpl}; ${ortUnpack8x4snormImpl}; @@ -143,14 +152,13 @@ export const createMatMulNBitsProgramInfo = var row = ${outputNumber} * m; var col = ${components} * n; var a_indices: ${a.type.indices} = output_indices; - // Two zero points are packed into one byte because uniforms.bits <= 4. - // zero_point_offset is nibble offset within one word. - // TODO support zero_point_offset for bits > 4 + // Two zero points are packed into one byte when uniforms.bits is 4. ${ zeroPoints ? 
` var zero_point_index: u32 = col * ((${nBlocksPerCol} + 1) / 2) / 4; var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')}; - var zero_point_offset: u32 = 0;` : + var zero_point_byte_offset: u32 = (col * ((${nBlocksPerCol} + 1) / 2)) % 4; + var zero_point_offset: u32 = 8 * zero_point_byte_offset;` : ''} var scale_index = col * ${nBlocksPerCol}; var b_indices: ${b.type.indices}; @@ -161,7 +169,7 @@ export const createMatMulNBitsProgramInfo = // The scale and zero points are computed per block. let scale = ${scales.getByOffset('scale_index')}; // The default zero point is 8 for unsigned 4-bit quantization. - let zero_point = ${dataType}(${zeroPoints ? 'zero_point_word & 0xFu' : 8.0}); + let zero_point = ${dataType}(${zeroPoints ? 'extractBits(zero_point_word, zero_point_offset, 4)' : 8.0}); ${b.indicesSet('b_indices', '1', 'block')}; var word_offset: u32 = block_offset; for (var word: u32 = 0; word < ${blobSizeInWords}; word += ${bComponents}) { @@ -187,18 +195,15 @@ export const createMatMulNBitsProgramInfo = } } scale_index++; - ${ - zeroPoints ? `if ((scale_index + ${nBlocksPerCol % 2 ? 'c' : '0'}) % 8 == 0) { - zero_point_index++; - zero_point_word = ${zeroPoints.getByOffset('zero_point_index')}; - } else { - zero_point_word >>= 4; - }` : - ''} + ${updateZeroPointIndex} block_offset += uniforms.block_size / ${aComponents}; } // Drop the trailing 4 bits if the zero_poit_offset is not a byte boundary to align with the next byte. - ${zeroPoints && (nBlocksPerCol % 2 !== 0) ? 'zero_point_word >>= 4;' : ''} + ${ + zeroPoints ? `if (zero_point_offset % 8 > 0) { + ${updateZeroPointIndex} + }` : + ''} } for (var k: u32 = 0u; k < ${outputNumber}u; k++) { ${output.indicesSet('output_indices', aRank - 2, `${outputNumber + ' * m + k'}`)};