diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts
index 5edfacd207c94..50ce7755d4d9e 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts
@@ -293,7 +293,7 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
-   * workgroupY is a divisor of dimBOuter (i.e. N), could be 1, 4, or 8
+   * workgroupY is a divisor of dimBOuter (i.e. N), could be 1, 4, 8, or 16
    * Workgroup cached A is shared within the same workgroupY
    */
-  const workgroupY = dimBOuter % 8 === 0 ? 8 : dimBOuter % 4 === 0 ? 4 : 1;
+  const workgroupY = dimBOuter % 16 === 0 ? 16 : dimBOuter % 8 === 0 ? 8 : dimBOuter % 4 === 0 ? 4 : 1;
   const workgroupX = workgroupSize / workgroupY;
   /**
    * tileSize = workgroupX * bComponents * 8
@@ -515,7 +515,7 @@ export const createMatMulNBitsSubgroupsBlockSize32NProgramInfo = (
   /** Total subgroup-cached input A size, in scalar */
   const scalarInputAPerSubgroupCache = vec4InputAPerSubgroupCache * 4;
   // /** Total subgroup-cached input A size, in A loading point */
-  // const loadingPointAPerSubgroupCache = loadingPointAPerSubgroupThread * subgroupSize;
+  const loadingPointAPerSubgroupCache = loadingPointAPerSubgroupThread * subgroupSize;
   if (
     scalarInputAPerSubgroupCache % blobSizeInScalarB !== 0 &&
     blobSizeInScalarB % scalarInputAPerSubgroupCache !== 0
@@ -523,9 +523,6 @@
   ) {
     throw new Error(`One of scalarInputAPerSubgroupCache (${scalarInputAPerSubgroupCache}) and \
 blobSizeInScalarB (${blobSizeInScalarB}) should divide another.`);
   }
-  // const blockBPerSubgroupCacheA = scalarInputAPerSubgroupCache / blobSizeInScalarB;
-  // const loadingPointBPerSubgroupCacheA = scalarInputAPerSubgroupCache / scalarsPerLoadingPointB;
-  // const subgroupCacheAPerBlockB = blobSizeInScalarB / scalarInputAPerSubgroupCache;
   // Each workgroup computes subgroupSize scalars of output
   const scalarOutputPerWorkgroup = subgroupSize;
@@ -627,12 +624,19 @@ blobSizeInScalarB (${blobSizeInScalarB}) should divide another.`);
         (_, loadingPoint) => `
       {
         let loading_point_A_col = loading_point_A_subgroup_cache_start_col + ${loadingPoint} * ${subgroupSize} + subgroup_id;
+        if (loading_point_A_col < ${a.shape}[2]) {
+          subgroup_cached_loading_points_A[${loadingPoint}] = ${a.getByIndices(`${a.type.indices}(batch, output_row_workgroup_base, loading_point_A_col)`)};
+        } else {
+          subgroup_cached_loading_points_A[${loadingPoint}] = ${a.type.value}();
+        }
+        /*
         subgroup_cached_loading_points_A[${loadingPoint}] = select(
           ${a.type.value}(0),
           ${a.getByIndices(`${a.type.indices}(batch, output_row_workgroup_base, loading_point_A_col)`)},
           loading_point_A_col < ${a.shape}[2]
         );
+        */
       }`,
     )
     .join('')}
@@ -683,7 +687,7 @@ blobSizeInScalarB (${blobSizeInScalarB}) should divide another.`);
           }`
             : ''
         }
-        acc += dot(${GetVec4AFromSubgroupCache(loadingPointASubgroupCacheStart)}, b_dequantized_values[${dequantizedMatrixRow}]);
+        inter_results[subgroup_id][k_split_id] += dot(${GetVec4AFromSubgroupCache(loadingPointASubgroupCacheStart)}, b_dequantized_values[${dequantizedMatrixRow}]);
       }`; // End of computation step body
     })
@@ -694,7 +698,7 @@ blobSizeInScalarB (${blobSizeInScalarB}) should divide another.`);
       // Update subgroup cache position on the end if next substep compute new subgroup cache
       ${
         (substep + 1) % substepsPerSubgroupCacheAUpdate === 0 && substep + 1 < loopSubsteps
-          ? 'loading_point_A_subgroup_cache_start_col += loadingPointAPerSubgroupCache;'
+          ? 'loading_point_A_subgroup_cache_start_col += LoadingPointAPerSubgroupCache;'
          : ''
      }
    }`;
@@ -703,11 +707,10 @@ blobSizeInScalarB (${blobSizeInScalarB}) should divide another.`);
   return `
 const BlockBPerLoop = ${blockBPerLoop}u;
 const LoadingPointAPerLoop = ${loadingPointAPerLoop}u;
+const LoadingPointAPerSubgroupCache = ${loadingPointAPerSubgroupCache}u;
 const KSplitFactor = ${kSplitFactor}u;
 var<workgroup> inter_results: array<array<${output.type.value}, ${kSplitFactor}>, ${subgroupSize}>;
-var<private> subgroup_cached_loading_points_A: array<${a.type.value}, ${loadingPointAPerSubgroupThread}>;
-var<private> acc: ${output.type.value};
 
 ${shaderHelper.declareVariables(...inputVariables, output)}
@@ -737,32 +740,47 @@ fn main(
   let input_b_col = output_col_workgroup_base + subgroup_id;
+  var subgroup_cached_loading_points_A: array<${a.type.value}, ${loadingPointAPerSubgroupThread}>;
+  // var acc: ${output.type.value};
+
   // Blocks are splited in an interleaved manner for k-split.
-  var block: u32;
+  // block = BlockBPerLoop * k_split_id + loop_K * BlockBPerLoop * KSplitFactor + 0..(BlockBPerLoop-1)
+  var block: u32 = BlockBPerLoop * k_split_id;
   var scale: ${scales.type.value};
-  // The default zero point is 8 for unsigned 4-bit quantization.
-  var zero_point = ${dataType}(${8.0});
+  ${zeroPoints ?
+    `var zero_point: ${dataType};` :
+    `// The default zero point is 8 for unsigned 4-bit quantization.
+  let zero_point = ${dataType}(${8.0});`}
   var loading_point_b: ${b.type.value};
   var b_dequantized_values: mat2x4<${dataType}>;
   // Subgroup cache A position
-  var loading_point_A_subgroup_cache_start_col: u32;
+  // loading_point_A_subgroup_cache_start_col =
+  //   LoadingPointAPerLoop * k_split_id +
+  //   loop_K * LoadingPointAPerLoop * KSplitFactor +
+  //   0:loadingPointAPerSubgroupCache:LoadingPointAPerLoop-LoadingPointAPerSubgroupCache
+  var loading_point_A_subgroup_cache_start_col: u32 = LoadingPointAPerLoop * k_split_id;
   for (var loop_K: u32 = 0; loop_K < loop_K_steps; loop_K++) {
-    let split_mapped_loop_K_id = (loop_K * KSplitFactor + k_split_id);
-    block = BlockBPerLoop * split_mapped_loop_K_id;
-    loading_point_A_subgroup_cache_start_col = LoadingPointAPerLoop * split_mapped_loop_K_id;
+    // let split_mapped_loop_K_id = (loop_K * KSplitFactor + k_split_id);
+    // block = BlockBPerLoop * split_mapped_loop_K_id;
+    // loading_point_A_subgroup_cache_start_col = LoadingPointAPerLoop * split_mapped_loop_K_id;
     ${Array.from({ length: loopSubsteps })
       .map((_, subStep) => loopSubStep(subStep))
      .join('\n')}
+
+    // block += BlockBPerLoop * KSplitFactor - (BlockBPerLoop - 1)
+    block += 1 + BlockBPerLoop * (KSplitFactor - 1);
+    // loading_point_A_subgroup_cache_start_col +=
+    //   LoadingPointAPerLoop * KSplitFactor - (LoadingPointAPerLoop - loadingPointAPerSubgroupCache)
+    loading_point_A_subgroup_cache_start_col += LoadingPointAPerSubgroupCache + LoadingPointAPerLoop * (KSplitFactor - 1);
   }
-  inter_results[subgroup_id][k_split_id] = acc;
+  // inter_results[subgroup_id][k_split_id] = acc;
   workgroupBarrier();
   if (local_idx < ${subgroupSize}) {
     var output_value: ${output.type.value} = ${output.type.value}(0);
-    for (var b = 0u; b < ${kSplitFactor}; b++) {
-      output_value += inter_results[subgroup_id][b];
-    }
+    ${Array.from({ length: kSplitFactor }).map((_, kSplit) => `
+    output_value += inter_results[subgroup_id][${kSplit}];`).join('')}
     if (output_col_workgroup_base + local_idx < uniforms.output_shape[2]) {
       ${output.setByIndices(`${output.type.indices}(batch, output_row_workgroup_base, output_col_workgroup_base + local_idx)`, 'output_value')}
@@ -771,7 +789,7 @@ fn main(
 }`;
   };
   return {
-    name: 'BlockwiseMatMulNBits32',
+    name: 'BlockwiseMatMulNBitsSubgroups32N',
     shaderCache: {
       hint: `${attributes.blockSize};${aComponents};${bComponents};${subgroupSize};${kSplitFactor}`,
       inputDependencies: Array(inputs.length).fill('rank'),
@@ -787,23 +805,32 @@
 export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => {
   validateInputs(context.inputs, attributes);
 
+  // if (
+  //   attributes.blockSize >= 32 &&
+  //   context.adapterInfo.isVendor('intel') &&
+  //   context.adapterInfo.isArchitecture('gen-12lp')
+  // ) {
+  //   context.compute(
+  //     createMatMulNBitsSubgroupsBlockSize32NProgramInfo(context.inputs, attributes, {
+  //       subgroupSize: 16,
+  //       kSplitFactor: 4,
+  //     }),
+  //   );
+  // } else
   if (
-    attributes.blockSize >= 32 &&
+    attributes.blockSize === 32 &&
     context.adapterInfo.isVendor('intel') &&
-    context.adapterInfo.isArchitecture('gen-12lp')
+    true
+    // context.adapterInfo.isArchitecture('gen-12lp')
   ) {
     context.compute(
       createMatMulNBitsSubgroupsBlockSize32NProgramInfo(context.inputs, attributes, {
         subgroupSize: 16,
-        kSplitFactor: 16,
+        kSplitFactor: 8,
       }),
     );
-  } else if (
-    attributes.blockSize === 32 &&
-    context.adapterInfo.isVendor('intel') &&
-    context.adapterInfo.isArchitecture('gen-12lp')
-  ) {
-    context.compute(createMatMulNBitsBlockSize32ProgramInfo(context.inputs, attributes));
+    // context.compute(createMatMulNBitsBlockSize32ProgramInfo(context.inputs, attributes));
+    // context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes));
   } else {
     context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes));
   }
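
Note (illustrative, not part of the patch): the interleaved k-split indexing introduced in the main loop can be sanity-checked with plain arithmetic. The TypeScript sketch below only mirrors the index math spelled out in the added WGSL comments; blockBPerLoop, kSplitFactor, and loopKSteps are assumed example values, not constants read from this file.

// Closed form from the added comments:
//   block = BlockBPerLoop * k_split_id + loop_K * BlockBPerLoop * KSplitFactor + 0..(BlockBPerLoop-1)
// Incremental form from the patch:
//   start at BlockBPerLoop * k_split_id, walk BlockBPerLoop blocks per iteration,
//   then advance by BlockBPerLoop * KSplitFactor - (BlockBPerLoop - 1), written as
//   block += 1 + BlockBPerLoop * (KSplitFactor - 1).
const blockBPerLoop = 4; // assumed example value
const kSplitFactor = 8; // assumed example value
const loopKSteps = 3; // assumed example value

for (let kSplitId = 0; kSplitId < kSplitFactor; kSplitId++) {
  // Incremental starting block for this k-split lane, as initialized before the loop in the patch.
  let block = blockBPerLoop * kSplitId;
  for (let loopK = 0; loopK < loopKSteps; loopK++) {
    // First block of this iteration must match the closed form for substep 0.
    const closedForm = blockBPerLoop * kSplitId + loopK * blockBPerLoop * kSplitFactor;
    console.assert(block === closedForm, `mismatch at loopK=${loopK}, kSplitId=${kSplitId}`);
    // Within the iteration the lane consumes blocks block .. block + BlockBPerLoop - 1,
    // ending at block + (BlockBPerLoop - 1); the patch then adds
    // 1 + BlockBPerLoop * (KSplitFactor - 1) to land on the next iteration's first block.
    block += (blockBPerLoop - 1) + 1 + blockBPerLoop * (kSplitFactor - 1);
  }
}

The same reasoning applies to loading_point_A_subgroup_cache_start_col, where the per-substep step is LoadingPointAPerSubgroupCache instead of 1 and the per-iteration stride is LoadingPointAPerLoop * KSplitFactor.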