Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JS/WebGPU] Optimize MatMulNBits #19852

Merged
merged 12 commits into from
Mar 13, 2024
208 changes: 137 additions & 71 deletions js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';

import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common';

// TODO support quantization bits not equal to 4
export interface MatMulNBitsAttributes extends AttributeWithCacheKey {
Expand Down Expand Up @@ -51,124 +51,190 @@ const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAt

export const createMatMulNBitsProgramInfo =
(inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): ProgramInfo => {
const a = inputs[0];
const b = inputs[1];
const scales = inputs[2];
const aRank = a.dims.length;
const outputShape = a.dims.slice(0, aRank - 1).concat(attributes.n);
const outputSize = ShapeUtil.size(outputShape);


const inputShape = inputs[0].dims;
const aRank = inputShape.length;
const outputShape = inputShape.slice(0, aRank - 1).concat(attributes.n);
const m = inputShape[aRank - 2];
const blobSize = attributes.blockSize / 8 * attributes.bits;
const blobSizeInWords = blobSize / 4;
const outputNumber = getMaxComponents(m);
const components = getMaxComponents(attributes.n);
const aComponents = getMaxComponents(attributes.k);
const bComponents = getMaxComponents(blobSizeInWords);
const outputSize = ShapeUtil.size(outputShape) / components / outputNumber;
const programUniforms: ProgramUniform[] = [
{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.k},
{type: DataType.uint32, data: attributes.n}, {type: DataType.uint32, data: attributes.accuracyLevel},
{type: DataType.uint32, data: attributes.bits}, {type: DataType.uint32, data: attributes.blockSize}
];
programUniforms.push(...createTensorShapeVariables(a.dims));
programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(b.dims)));
programUniforms.push(...createTensorShapeVariables(scales.dims));
const aShape = inputShape.slice();
aShape.splice(-1, 1, attributes.k / aComponents);
const bShape = ShapeUtil.convertShape(inputs[1].dims).slice();
bShape.splice(-1, 1, blobSizeInWords / bComponents);
programUniforms.push(...createTensorShapeVariables(aShape));
programUniforms.push(...createTensorShapeVariables(bShape));
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
if (inputs.length === 4) {
programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims)));
}
programUniforms.push(...createTensorShapeVariables(outputShape));
const oShape = outputShape.slice();
oShape.splice(-1, 1, attributes.n / components);
programUniforms.push(...createTensorShapeVariables(oShape));
const getShaderSource = (shaderHelper: ShaderHelper) => {
const a = inputVariable('a', inputs[0].dataType, inputs[0].dims.length);
const b = inputVariable('b', DataType.uint32, inputs[1].dims.length);
const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents);
const b = inputVariable('b', DataType.uint32, bShape.length, bComponents);
const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length);
const inputVariables = [a, b, scales];
const zeroPoints =
inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined;
if (zeroPoints) {
inputVariables.push(zeroPoints);
}
const output = outputVariable('output', inputs[0].dataType, outputShape.length);
const output = outputVariable('output', inputs[0].dataType, outputShape.length, components);
const uniforms: UniformsArrayType = [
{name: 'output_size', type: 'u32'}, {name: 'k', type: 'u32'}, {name: 'n', type: 'u32'},
{name: 'output_size', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'},
{name: 'accuracy_level', type: 'u32'}, {name: 'bits', type: 'u32'}, {name: 'block_size', type: 'u32'}
];
const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize);
const blobSize = attributes.blockSize / 8 * attributes.bits;
const wordPerBlob = blobSize / 4;
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
return `
fn ortUnpack8x4snorm(value: u32) -> array<${dataType}, 8>{
var result = array<${dataType}, 8>();

const qDqDataType = (() => {
switch (aComponents) {
case 1:
return `array<${dataType}, 8>`;
case 2:
return `mat4x2<${dataType}>`;
case 4:
return `mat2x4<${dataType}>`;
default:
throw new Error(`${aComponents}-component is not supported.`);
}
})();

const dequantizeImpl = `
fn dequantize(quantized: ${qDqDataType}, zero_point: ${dataType}, scale: ${dataType}) -> ${qDqDataType} {
${(() => {
if (aComponents === 1) {
return `var dequantized = ${qDqDataType}(${
Array.from({length: 8}, (_, i) => `(quantized[${i}] - zero_point) * scale`).join(', ')});
return dequantized;`;
} else {
return `var zero_points: ${qDqDataType} = ${qDqDataType}(${Array(8).fill('zero_point').join(',')});
return (quantized - zero_points) * scale;`;
}
})()}
}`;
const ortUnpack8x4snormImpl = `
fn ortUnpack8x4snorm(value: u32) -> ${qDqDataType} {
var quantized: ${qDqDataType};
var offset: u32 = 0;
let count: u32 = 4;
for (var i: u32 = 0; i < 8u; i++) {
result[i] = ${dataType}(extractBits(value, offset, count));
var result = ${dataType}(extractBits(value, offset, count));
${(() => {
switch (aComponents) {
case 1:
return 'quantized[i] = result;';
case 2:
return 'quantized[i / 2][i % 2] = result;';
case 4:
return 'quantized[i / 4][i % 4] = result;';
default:
throw new Error(`${aComponents}-component is not supported.`);
}
})()}
offset += count;
}
return result;
}
return quantized;
}`;

const updateZeroPointIndex = zeroPoints ? `
zero_point_offset += 4;
if (zero_point_offset == 32) {
zero_point_offset = 0;
zero_point_index++;
zero_point_word = ${zeroPoints.getByOffset('zero_point_index')};
}` :
'';

return `
${dequantizeImpl};
${ortUnpack8x4snormImpl};
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
var value: ${dataType} = 0.0;
let output_indices = ${output.offsetToIndices('global_idx')};
var a_indices: ${a.type.indices} = output_indices;
var output_values: array<${output.type.value}, ${outputNumber}>;
var output_indices = ${output.offsetToIndices('global_idx')};
var n = ${output.indicesGet('output_indices', aRank - 1)};
var m = ${output.indicesGet('output_indices', aRank - 2)};
var a_indices: ${a.type.indices} = output_indices;
// Two zero points are packed into one byte because uniforms.bits <= 4.
// zero_point_offset is either 0 or 4. It is bit offset within one byte.
// TODO support zero_point_offset for bits > 4
${
zeroPoints ? `
var zero_point_index: u32 = n * ((${nBlocksPerCol} + 1) / 2) / 4;
var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')};
var zero_point_offset: u32 = 0;` :
var zero_point_index: u32 = n * ${components} * ((${nBlocksPerCol} + 1) / 2) / 4;
var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')};
var zero_point_offset: u32 = 0;` :
''}
var scale_idex = n * ${nBlocksPerCol};
var scale_index = n * ${nBlocksPerCol * components};
var b_indices: ${b.type.indices};
${b.indicesSet('b_indices', '0', 'n')};
var block_offset: u32 = 0;
for (var block: u32 = 0; block < ${nBlocksPerCol}; block++) {
// The scale and zero points are computed per block.
let scale = ${scales.getByOffset('scale_idex')};
// The default zero point is 8 for unsigned 4-bit quantization.
let zero_point: ${dataType} = ${
zeroPoints ? `${dataType}(extractBits(zero_point_word, zero_point_offset, 4))` : 8.0};
${b.indicesSet('b_indices', '1', 'block')};
var word_offset: u32 = block_offset;
for (var word: u32 = 0; word < ${wordPerBlob}; word++) {
${b.indicesSet('b_indices', '2', 'word')};
let b_value = ${b.getByIndices('b_indices')};
let b_quantized_values: array<${dataType}, 8> = ortUnpack8x4snorm(b_value);
// Number of B elements per 32-bit word is 32/bits = 32/4 = 8
var offset: u32 = word_offset;
for (var i: u32 = 0; i < 8; i++) {
${a.indicesSet('a_indices', aRank - 1, 'offset')};
let a_value = ${a.getByIndices('a_indices')};
let b_quantized_value = b_quantized_values[i];
let b_dequantized_value = (b_quantized_value - zero_point) * scale;
value += a_value * b_dequantized_value;
offset++;
for (var c: u32 = 0; c < ${components}; c++) {
${b.indicesSet('b_indices', '0', `n * ${components} + c`)};
var block_offset: u32 = 0;
for (var block: u32 = 0; block < ${nBlocksPerCol}; block++) {
// The scale and zero points are computed per block.
let scale = ${scales.getByOffset('scale_index')};
// The default zero point is 8 for unsigned 4-bit quantization.
let zero_point = ${dataType}(${zeroPoints ? 'extractBits(zero_point_word, zero_point_offset, 4)' : 8.0});
${b.indicesSet('b_indices', '1', 'block')};
var word_offset: u32 = block_offset;
for (var word: u32 = 0; word < ${blobSizeInWords}; word += ${bComponents}) {
${b.indicesSet('b_indices', '2', 'word')};
let b_data = ${b.getByIndices('b_indices')};
for (var i: u32 = 0; i < ${bComponents}; i++) {
let b_value = ${bComponents === 1 ? 'b_data' : 'b_data[word + i]'};
let b_quantized_values: ${qDqDataType} = ortUnpack8x4snorm(b_value);
let b_dequantized_values = dequantize(b_quantized_values, zero_point, scale);
// Number of B elements per 32-bit word is 32/bits = 32/4 = 8
var offset: u32 = word_offset;
for (var j: u32 = 0; j < 8/${aComponents}; j++) {
${a.indicesSet('a_indices', aRank - 1, `offset/${aComponents}`)};
for (var k: u32 = 0; k < ${outputNumber}u; k++) {
${a.indicesSet('a_indices', aRank - 2, `m * ${outputNumber} + k`)};
let a_data = ${a.getByIndices('a_indices')};
output_values[k]${components > 1 ? '[c]' : ''} += ${
aComponents === 1 ? 'a_data * b_dequantized_values[j]' : 'dot(a_data, b_dequantized_values[j])'};
}
offset += ${aComponents};
}
word_offset += 8;
}
}
word_offset += 8;
scale_index++;
${updateZeroPointIndex}
block_offset += uniforms.block_size;
}
scale_idex++;
// Drop the trailing 4 bits if the zero_point_offset is not a byte boundary to align with the next byte.
${
zeroPoints ? `
if (zero_point_offset == 28) {
zero_point_offset = 0;
zero_point_index++;
zero_point_word = ${zeroPoints.getByOffset('zero_point_index')};
} else {
zero_point_offset += 4;
}` :
zeroPoints ? `if (zero_point_offset % 8 > 0) {
${updateZeroPointIndex}
}` :
''}
block_offset += uniforms.block_size;
}
${output.setByOffset('global_idx', 'value')};
}
`;
}
for (var k: u32 = 0u; k < ${outputNumber}u; k++) {
${output.indicesSet('output_indices', aRank - 2, `${outputNumber + ' * m + k'}`)};
${output.setByIndices('output_indices', 'output_values[k]')}
}
}`;
};
return {
name: 'MatMulNBits',
shaderCache:
{hint: `${attributes.cacheKey};${inputs.length}`, inputDependencies: Array(inputs.length).fill('rank')},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64)},
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms
}),
getShaderSource
Expand Down
57 changes: 57 additions & 0 deletions js/web/test/data/ops/matmulnbits.jsonc
Original file line number Diff line number Diff line change
@@ -1,4 +1,61 @@
[
{
"name": "MatMulNBits; K=16, N=8, block_size=16, bits=4",
"operator": "MatMulNBits",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
{ "name": "K", "data": 16, "type": "int" },
{ "name": "N", "data": 8, "type": "int" },
{ "name": "block_size", "data": 16, "type": "int" },
{ "name": "bits", "data": 4, "type": "int" }
],
"cases": [
{
"name": "MatMulNBits; K=16, N=8, block_size=16, bits=4; symmetric",
"inputs": [
{
"data": [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105,
106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
127
],
"dims": [8, 16],
"type": "float32"
},
{
"dims": [8, 1, 8],
"type": "uint8",
"data": [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64
]
},
{
"dims": [8],
"type": "float32",
"data": [0, 1, 2, 3, 4, 5, 6, 7]
}
],
"outputs": [
{
"dims": [8, 8],
"type": "float32",
"data": [
0, -385, -1120, -963, -1984, -1285, -2592, -1351, 0, -1073, -3808, -2643, -6848, -3445, -9120, -3479, 0,
-1761, -6496, -4323, -11712, -5605, -15648, -5607, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735,
0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, 0, -3825, -14560, -9363, -26304, -12085, -35232,
-11991, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, 0, -5201, -19936, -12723, -36032,
-16405, -48288, -16247
]
}
]
}
]
},
{
"name": "MatMulNBits; K=16, N=16, block_size=16, bits=4",
"operator": "MatMulNBits",
Expand Down
Loading