Skip to content

Commit

Permalink
clean code
Browse files (browse the repository at this point in the history)
  • Loading branch information
qjia7 committed Aug 16, 2024
1 parent 6bd5417 commit d28ac0a
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts
Original file line number | Diff line number | Diff line change
Expand Up @@ -326,7 +326,7 @@ export const createMatMulNBitsProgramInfo = (
};
};

// zeroPoints = null
// TODO: support zeroPoints as input
export const createMatMulNBitsBlockwiseProgramInfo = (
inputs: readonly TensorView[],
attributes: MatMulNBitsAttributes,
Expand Down Expand Up @@ -364,11 +364,6 @@ export const createMatMulNBitsBlockwiseProgramInfo = (
const b = inputVariable('b', DataType.uint32, bShape.length, bComponents);
const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length);
const inputVariables = [a, b, scales];
const zeroPoints =
inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined;
if (zeroPoints) {
inputVariables.push(zeroPoints);
}
const outputRank = outputShapeTemp.length;
const output = outputVariable('output', inputs[0].dataType, outputRank, components);
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
Expand Down Expand Up @@ -456,8 +451,8 @@ export const createMatMulNBitsBlockwiseProgramInfo = (
var row = workgroup_id.y;
var batch = workgroup_id.z;
// Two zero points are packed into one byte when uniforms.bits is 4.
let zero_point = ${dataType}(${zeroPoints ? '(zero_point_word) & 0xFu' : 8.0});
// The default zero point is 8 for unsigned 4-bit quantization.
let zero_point = ${dataType}(${8.0});
var word_offset: u32 = block * ${attributes.blockSize / aComponents};
//process one block
Expand Down Expand Up @@ -488,7 +483,7 @@ export const createMatMulNBitsBlockwiseProgramInfo = (
return {
name: 'BlockwiseMatMulNBitsV1',
shaderCache: {
hint: `${attributes.cacheKey};${dataType};${outputNumber}`,
hint: `${attributes.blockSize};${dataType};${outputNumber};${nBlocksPerCol}`,
inputDependencies: Array(inputs.length).fill('rank'),
},
getRunData: () => ({
Expand Down

0 comments on commit d28ac0a

Please sign in to comment.