Skip to content

Commit

Permalink
clean code
Browse files Browse the repository at this point in the history
  • Loading branch information
qjia7 committed Oct 12, 2024
1 parent ed571b6 commit 89201c3
Showing 1 changed file with 55 additions and 66 deletions.
121 changes: 55 additions & 66 deletions js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,6 @@ export const createMatMulNBitsProgramInfo = (
};
};

// TODO: support zeroPoints as input
// Currently, only support blockSize = 32.
export const createMatMulNBitsBlockSize32ProgramInfo = (
inputs: readonly TensorView[],
Expand All @@ -284,16 +283,15 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
const dataType = inputs[0].dataType;
const aComponents = getMaxComponents(attributes.k);
const bComponents = getMaxComponents(blobSizeInWords);
// const components = getMaxComponents(dimBOuter);
const components = 1;
const outputShape = batchDims.concat([dimAOuter, dimBOuter]);

const workgroupSize = 128;
const workgroupY = 8;
const workgroupY = dimBOuter % 8 === 0 ? 8 : dimBOuter % 4 === 0 ? 4 : 1;
const workgroupX = workgroupSize / workgroupY;
const tileSize = workgroupX * bComponents * 8; // each uint32 has 8 data.
const aLengthPerTile = tileSize / aComponents;
const blocksPerTile = tileSize / attributes.blockSize; // This requires tileSize must be larger than or equal to blockSize.
const blocksPerTile = tileSize / attributes.blockSize;
const dispatchSize = ShapeUtil.size(outputShape) / workgroupY;

const programUniforms: ProgramUniform[] = [];
const inputShapeTemp = [batchSize, dimAOuter, dimInner / aComponents];
Expand All @@ -302,7 +300,10 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
programUniforms.push(...createTensorShapeVariables(inputShapeTemp));
programUniforms.push(...createTensorShapeVariables(bShape));
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components];
if (inputs.length === 4) {
programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims)));
}
const outputShapeTemp = [batchSize, dimAOuter, dimBOuter];
programUniforms.push(...createTensorShapeVariables(outputShapeTemp));

const getShaderSource = (shaderHelper: ShaderHelper) => {
Expand All @@ -311,10 +312,15 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
const b = inputVariable('b', DataType.uint32, bShape.length, bComponents);
const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length);
const inputVariables = [a, b, scales];
const zeroPoints =
inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined;
if (zeroPoints) {
inputVariables.push(zeroPoints);
}
const outputRank = outputShapeTemp.length;
const output = outputVariable('output', inputs[0].dataType, outputRank, components);
const output = outputVariable('output', inputs[0].dataType, outputRank);
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const readA = (() => {
const readA = () => {
switch (aComponents) {
case 1:
return `
Expand All @@ -331,66 +337,19 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
default:
throw new Error(`${aComponents}-component is not supported.`);
}
});

const processOneWord = (): string => {
let calcStr = readA();
for (let c = 0; c < components; c++) {
calcStr += `
b_value = ${bComponents === 1 ? `b${c}_data` : `b${c}_data[i]`};
b_value_lower = unpack4xU8(b_value & b_mask);
b_value_upper = unpack4xU8((b_value >> 4) & b_mask);
b_quantized_values = mat2x4<${dataType}>(${Array.from(
{ length: 4 },
(_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`,
).join(', ')});
b_dequantized_values = ${(() => {
return `(b_quantized_values - mat2x4<${dataType}>(${Array(8).fill('zero_point').join(',')})) * scale${c};`;
})()};
inter_results[local_id.y][local_id.x] += ${Array.from(
{ length: 2 },
(_, i) =>
`${
`dot(a_data${i}, b_dequantized_values[${i}])`
}`,
).join(' + ')};
`;
}
return calcStr;
};

const prepareScaleAndBData = (): string => {
let calcStr = `var col_index = col * ${components};`;
for (let c = 0; c < components; c++) {
calcStr += `
let b_row = workgroup_id.x * ${workgroupY} + local_id.y;
let block = tile * ${blocksPerTile} + local_id.x;
let scale${c} = ${scales.getByOffset(`b_row * n_blocks_per_col + block`)};
let b${c}_data = ${b.getByIndices(`${b.type.indices}(b_row, block, 0)`)};
col_index += 1;`;
}
calcStr += `
var b_value: u32;
let b_mask: u32 = 0x0F0F0F0Fu;
var b_value_lower: vec4<u32>;
var b_value_upper: vec4<u32>;
var b_quantized_values:mat2x4<${dataType}>;
var b_dequantized_values: mat2x4<${dataType}>;`;
return calcStr;
};
return `
var<workgroup> sub_a: array<${a.type.value}, ${aLengthPerTile}>;
var<workgroup> inter_results: array<array<${output.type.value}, ${workgroupX}>, ${workgroupY}>;
${shaderHelper.declareVariables(...inputVariables, output)}
${shaderHelper.mainStart([workgroupX, workgroupY, 1])}
let col = workgroup_id.x * ${workgroupY} + local_idx;
let row = workgroup_id.y;
let batch = workgroup_id.z;
let output_indices = ${output.offsetToIndices(`workgroup_index * ${workgroupY}`)};
let col = output_indices[2];
let row = output_indices[1];
let batch = output_indices[0];
let n_blocks_per_col = uniforms.b_shape[1];
let num_tiles = (n_blocks_per_col - 1) / ${blocksPerTile} + 1;
let blob_size_in_words = uniforms.b_shape[2];
// The default zero point is 8 for unsigned 4-bit quantization.
let zero_point = ${dataType}(${8.0});
// Loop over shared dimension.
for (var tile: u32 = 0; tile < num_tiles; tile += 1) {
Expand All @@ -409,10 +368,40 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
workgroupBarrier();
// each thread process one block
let b_row = col + local_id.y;
let block = tile * ${blocksPerTile} + local_id.x;
${
zeroPoints
? `
let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;
let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block >> 0x1u);
let zero_point_word_index = zero_point_byte_count >> 0x2u;
let zero_point_byte_offset = zero_point_byte_count & 0x3u;
let zero_point_nibble_offset: u32 = block & 0x1u;
let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);
let zero_point_word = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset;
let zero_point = ${dataType}((zero_point_word) & 0xFu);`
: `
// The default zero point is 8 for unsigned 4-bit quantization.
let zero_point = ${dataType}(${8.0});`
}
let scale = ${scales.getByOffset(`b_row * n_blocks_per_col + block`)};
let b_data = ${b.getByIndices(`${b.type.indices}(b_row, block, 0)`)};
var word_offset = local_id.x * ${attributes.blockSize / aComponents};
${prepareScaleAndBData()}
for (var i: u32 = 0; i < ${bComponents}; i++) {
${processOneWord()}
${readA()}
let b_value = ${bComponents === 1 ? `b_data` : `b_data[i]`};
let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
let b_quantized_values = mat2x4<${dataType}>(${Array.from(
{ length: 4 },
(_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`,
).join(', ')});
let b_dequantized_values = (b_quantized_values - mat2x4<${dataType}>(${Array(8).fill('zero_point').join(',')})) * scale;
inter_results[local_id.y][local_id.x] += ${Array.from(
{ length: 2 },
(_, i) => `${`dot(a_data${i}, b_dequantized_values[${i}])`}`,
).join(' + ')};
word_offset += ${8 / aComponents};
}
workgroupBarrier();
Expand All @@ -423,22 +412,22 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
for (var b = 0u; b < ${workgroupX}; b++) {
output_value += inter_results[local_idx][b];
}
if (col < uniforms.output_shape[2])
if (col + local_idx < uniforms.output_shape[2])
{
${output.setByIndices(`${output.type.indices}(batch, row, col)`, 'output_value')}
${output.setByIndices(`${output.type.indices}(batch, row, col + local_idx)`, 'output_value')}
}
}
}`;
};
return {
name: 'BlockwiseMatMulNBits32',
shaderCache: {
hint: `${attributes.blockSize};${attributes.bits};${aComponents};${bComponents};${components}`,
hint: `${attributes.blockSize};${aComponents};${bComponents};${workgroupX};${workgroupY}`,
inputDependencies: Array(inputs.length).fill('rank'),
},
getRunData: () => ({
outputs: [{ dims: outputShape, dataType }],
dispatchGroup: { x: Math.ceil(dimBOuter / components / workgroupY), y: dimAOuter, z: batchSize },
dispatchGroup: { x: dispatchSize },
programUniforms,
}),
getShaderSource,
Expand All @@ -447,7 +436,7 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (

export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => {
validateInputs(context.inputs, attributes);
if(context.inputs.length < 4 && attributes.blockSize == 32 && context.adapterInfo.isVendor("intel")) {
if (attributes.blockSize === 32 && context.adapterInfo.isVendor('intel')) {
context.compute(createMatMulNBitsBlockSize32ProgramInfo(context.inputs, attributes));
} else {
context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes));
Expand Down

0 comments on commit 89201c3

Please sign in to comment.