From d38be7633d1c26467f6363b3a67a75ba9ac34e48 Mon Sep 17 00:00:00 2001
From: Jiajia Qin
Date: Thu, 12 Dec 2024 18:24:40 +0800
Subject: [PATCH] Optimize matmulnbits with M > 1

---
 .../lib/wasm/jsep/webgpu/ops/matmulnbits.ts | 186 +++++++++++++++++-
 1 file changed, 185 insertions(+), 1 deletion(-)

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts
index 3e1f1be22efa2..5a6ddb3a1bed5 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts
@@ -434,9 +434,193 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
   };
 };
 
+// Currently, only blockSize = 32 is supported.
+export const createMatMulNBitsWithLargeMProgramInfo = (
+  inputs: readonly TensorView[],
+  attributes: MatMulNBitsAttributes,
+): ProgramInfo => {
+  const inputShape = inputs[0].dims;
+  const aRank = inputShape.length;
+  const dimAOuter = inputShape[aRank - 2];
+  const dimInner = attributes.k;
+  const dimBOuter = attributes.n;
+  const batchDims = inputShape.slice(0, aRank - 2);
+  const batchSize = ShapeUtil.size(batchDims);
+  const blobSize = inputs[1].dims[2];
+  const blobSizeInWords = blobSize / 4;
+  const dataType = inputs[0].dataType;
+  const aComponents = getMaxComponents(attributes.k);
+  const bComponents = getMaxComponents(blobSizeInWords);
+  const outputShape = batchDims.concat([dimAOuter, dimBOuter]);
+
+  const workgroupSize = 64;
+  const tileM = 4;
+  const workgroupY = 8;
+  const workgroupX = workgroupSize / workgroupY;
+  const tileSize = workgroupX * bComponents * 8; // each uint32 word holds 8 4-bit values.
+  const aLengthPerTile = tileSize / aComponents;
+  const blocksPerTile = tileSize / attributes.blockSize;
+
+  const programUniforms: ProgramUniform[] = [];
+  const inputShapeTemp = [batchSize, dimAOuter, dimInner / aComponents];
+  const bShape = ShapeUtil.convertShape(inputs[1].dims).slice();
+  bShape.splice(-1, 1, blobSizeInWords / bComponents);
+  programUniforms.push(...createTensorShapeVariables(inputShapeTemp));
+  programUniforms.push(...createTensorShapeVariables(bShape));
+  programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
+  if (inputs.length === 4) {
+    programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims)));
+  }
+  const outputShapeTemp = [batchSize, dimAOuter, dimBOuter];
+  programUniforms.push(...createTensorShapeVariables(outputShapeTemp));
+
+  const getShaderSource = (shaderHelper: ShaderHelper) => {
+    const inputRank = inputShapeTemp.length;
+    const a = inputVariable('a', inputs[0].dataType, inputRank, aComponents);
+    const b = inputVariable('b', DataType.uint32, bShape.length, bComponents);
+    const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length);
+    const inputVariables = [a, b, scales];
+    const zeroPoints =
+      inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined;
+    if (zeroPoints) {
+      inputVariables.push(zeroPoints);
+    }
+    const outputRank = outputShapeTemp.length;
+    const output = outputVariable('output', inputs[0].dataType, outputRank);
+    const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
+    const readA = () => {
+      switch (aComponents) {
+        case 1:
+          return `
+          let a_data0 = vec4<${dataType}>(sub_a[r][word_offset], sub_a[r][word_offset + 1], sub_a[r][word_offset + 2], sub_a[r][word_offset + 3]);
+          let a_data1 = vec4<${dataType}>(sub_a[r][word_offset + 4], sub_a[r][word_offset + 5], sub_a[r][word_offset + 6], sub_a[r][word_offset + 7]);`;
+        case 2:
+          return `
+          let a_data0 = vec4<${dataType}>(sub_a[r][word_offset], sub_a[r][word_offset + 1]);
+          let a_data1 = vec4<${dataType}>(sub_a[r][word_offset + 2], sub_a[r][word_offset + 3]);`;
+        case 4:
+          return `
+          let a_data0 = sub_a[r][word_offset];
+          let a_data1 = sub_a[r][word_offset + 1];`;
+        default:
+          throw new Error(`${aComponents}-component is not supported.`);
+      }
+    };
+
+    const loadTileA = () => {
+      let str = '';
+      for (let i = 0; i < tileM; i++) {
+        str += `sub_a[${i}][a_offset] = mm_readA(batch, row + ${i}, a_col);`;
+      }
+      return str;
+    };
+    return `
+        fn mm_readA(batch: u32, row : u32, col : u32) -> ${a.type.value} {
+          if (row < uniforms.a_shape[1] && col < uniforms.a_shape[2])
+          {
+            return ${a.getByIndices(`${a.type.indices}(batch, row, col)`)};
+          } else {
+            return ${a.type.value}(0);
+          }
+        }
+
+        var<workgroup> sub_a: array<array<${a.type.value}, ${aLengthPerTile}>, ${tileM}>;
+        var<workgroup> inter_results: array<array<array<${output.type.value}, ${workgroupX}>, ${workgroupY}>, ${tileM}>;
+        ${shaderHelper.declareVariables(...inputVariables, output)}
+        ${shaderHelper.mainStart([workgroupX, workgroupY, 1])}
+          let col = workgroup_id.x * ${workgroupY};
+          let row = workgroup_id.y * ${tileM};
+          let batch = workgroup_id.z;
+          let n_blocks_per_col = uniforms.b_shape[1];
+          let num_tiles = (n_blocks_per_col - 1) / ${blocksPerTile} + 1;
+
+          // Loop over shared dimension.
+          for (var tile: u32 = 0; tile < num_tiles; tile += 1) {
+            let a_col_start = tile * ${aLengthPerTile};
+            // Load one tile of A into shared memory.
+            for (var a_offset = local_idx; a_offset < ${aLengthPerTile}; a_offset += ${workgroupSize})
+            {
+              let a_col = a_col_start + a_offset;
+              ${loadTileA()}
+            }
+            workgroupBarrier();
+
+            // Each thread processes one block.
+            let b_col = col + local_id.y;
+            let block = tile * ${blocksPerTile} + local_id.x;
+            ${
+              zeroPoints
+                ? `
+            let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;
+            let zero_point_byte_count = b_col * zero_point_bytes_per_col + (block >> 0x1u);
+            let zero_point_word_index = zero_point_byte_count >> 0x2u;
+            let zero_point_byte_offset = zero_point_byte_count & 0x3u;
+            let zero_point_nibble_offset: u32 = block & 0x1u;
+            let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);
+            let zero_point_word = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset;
+            let zero_point = ${dataType}((zero_point_word) & 0xFu);`
+                : `
+            // The default zero point is 8 for unsigned 4-bit quantization.
+            let zero_point = ${dataType}(${8.0});`
+            }
+            let scale = ${scales.getByOffset(`b_col * n_blocks_per_col + block`)};
+            let b_data = ${b.getByIndices(`${b.type.indices}(b_col, block, 0)`)};
+            var word_offset = local_id.x * ${attributes.blockSize / aComponents};
+            for (var i: u32 = 0; i < ${bComponents}; i++) {
+              let b_value = ${bComponents === 1 ? `b_data` : `b_data[i]`};
+              let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
+              let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
+              let b_quantized_values = mat2x4<${dataType}>(${Array.from(
+                { length: 4 },
+                (_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`,
+              ).join(', ')});
+              let b_dequantized_values = (b_quantized_values - mat2x4<${dataType}>(${Array(8).fill('zero_point').join(',')})) * scale;
+              for (var r = 0; r < ${tileM}; r++) {
+                ${readA()}
+                inter_results[r][local_id.y][local_id.x] += ${Array.from(
+                  { length: 2 },
+                  (_, i) => `${`dot(a_data${i}, b_dequantized_values[${i}])`}`,
+                ).join(' + ')};
+              }
+              word_offset += ${8 / aComponents};
+            }
+            workgroupBarrier();
+          }
+
+          if (local_id.y < ${tileM}) {
+            var output_value: ${output.type.value} = ${output.type.value}(0);
+            for (var b = 0u; b < ${workgroupX}; b++) {
+              output_value += inter_results[local_id.y][local_id.x][b];
+            }
+            if (row + local_id.y < uniforms.output_shape[1] && col + local_id.x < uniforms.output_shape[2])
+            {
+              ${output.setByIndices(`${output.type.indices}(batch, row + local_id.y, col + local_id.x)`, 'output_value')}
+            }
+          }
+        }`;
+  };
+  return {
+    name: 'MatMulNBitsWithLargeM',
+    shaderCache: {
+      hint: `${attributes.blockSize};${aComponents};${bComponents};${workgroupX};${workgroupY};${tileM}`,
+      inputDependencies: Array(inputs.length).fill('rank'),
+    },
+    getRunData: () => ({
+      outputs: [{ dims: outputShape, dataType }],
+      dispatchGroup: { x: Math.ceil(dimBOuter / workgroupY), y: Math.ceil(dimAOuter / tileM), z: batchSize },
+      programUniforms,
+    }),
+    getShaderSource,
+  };
+};
+
 export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => {
   validateInputs(context.inputs, attributes);
-  if (
+  const inputShape = context.inputs[0].dims;
+  const m = inputShape[inputShape.length - 2];
+  if (m > 1 && attributes.blockSize === 32) {
+    context.compute(createMatMulNBitsWithLargeMProgramInfo(context.inputs, attributes));
+  } else if (
     attributes.blockSize === 32 &&
     context.adapterInfo.isVendor('intel') &&
     context.adapterInfo.isArchitecture('gen-12lp')