Skip to content

Commit

Permalink
Improving 0
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangzhaoming committed Dec 19, 2024
1 parent 71dc1d0 commit d390058
Showing 1 changed file with 57 additions and 30 deletions.
87 changes: 57 additions & 30 deletions js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
* workgroupY is a divisor of dimBOuter (i.e. N), could be 1, 4, 8, or 16
* Workgroup cached A is shared within the same workgroupY
*/
const workgroupY = dimBOuter % 8 === 0 ? 8 : dimBOuter % 4 === 0 ? 4 : 1;
const workgroupY = dimBOuter % 16 === 0 ? 16 : dimBOuter % 8 === 0 ? 8 : dimBOuter % 4 === 0 ? 4 : 1;
const workgroupX = workgroupSize / workgroupY;
/**
* tileSize = workgroupX * bComponents * 8
Expand Down Expand Up @@ -515,17 +515,14 @@ export const createMatMulNBitsSubgroupsBlockSize32NProgramInfo = (
/** Total subgroup-cached input A size, in scalar */
const scalarInputAPerSubgroupCache = vec4InputAPerSubgroupCache * 4;
/** Total subgroup-cached input A size, in A loading points */
// const loadingPointAPerSubgroupCache = loadingPointAPerSubgroupThread * subgroupSize;
const loadingPointAPerSubgroupCache = loadingPointAPerSubgroupThread * subgroupSize;
if (
scalarInputAPerSubgroupCache % blobSizeInScalarB !== 0 &&
blobSizeInScalarB % scalarInputAPerSubgroupCache !== 0
) {
throw new Error(`One of scalarInputAPerSubgroupCache (${scalarInputAPerSubgroupCache}) and \
blobSizeInScalarB (${blobSizeInScalarB}) should divide another.`);
}
// const blockBPerSubgroupCacheA = scalarInputAPerSubgroupCache / blobSizeInScalarB;
// const loadingPointBPerSubgroupCacheA = scalarInputAPerSubgroupCache / scalarsPerLoadingPointB;
// const subgroupCacheAPerBlockB = blobSizeInScalarB / scalarInputAPerSubgroupCache;

// Each workgroup computes subgroupSize scalars of output
const scalarOutputPerWorkgroup = subgroupSize;
Expand Down Expand Up @@ -627,12 +624,19 @@ blobSizeInScalarB (${blobSizeInScalarB}) should divide another.`);
(_, loadingPoint) => `
{
let loading_point_A_col = loading_point_A_subgroup_cache_start_col + ${loadingPoint} * ${subgroupSize} + subgroup_id;
if (loading_point_A_col < ${a.shape}[2]) {
subgroup_cached_loading_points_A[${loadingPoint}] = ${a.getByIndices(`${a.type.indices}(batch, output_row_workgroup_base, loading_point_A_col)`)};
} else {
subgroup_cached_loading_points_A[${loadingPoint}] = ${a.type.value}();
}
/*
subgroup_cached_loading_points_A[${loadingPoint}] =
select(
${a.type.value}(0),
${a.getByIndices(`${a.type.indices}(batch, output_row_workgroup_base, loading_point_A_col)`)},
loading_point_A_col < ${a.shape}[2]
);
*/
}`,
)
.join('')}
Expand Down Expand Up @@ -683,7 +687,7 @@ blobSizeInScalarB (${blobSizeInScalarB}) should divide another.`);
}`
: ''
}
acc += dot(${GetVec4AFromSubgroupCache(loadingPointASubgroupCacheStart)}, b_dequantized_values[${dequantizedMatrixRow}]);
inter_results[subgroup_id][k_split_id] += dot(${GetVec4AFromSubgroupCache(loadingPointASubgroupCacheStart)}, b_dequantized_values[${dequantizedMatrixRow}]);
}`;
// End of computation step body
})
Expand All @@ -694,7 +698,7 @@ blobSizeInScalarB (${blobSizeInScalarB}) should divide another.`);
// Update subgroup cache position on the end if next substep compute new subgroup cache
${
(substep + 1) % substepsPerSubgroupCacheAUpdate === 0 && substep + 1 < loopSubsteps
? 'loading_point_A_subgroup_cache_start_col += loadingPointAPerSubgroupCache;'
? 'loading_point_A_subgroup_cache_start_col += LoadingPointAPerSubgroupCache;'
: ''
}
}`;
Expand All @@ -703,11 +707,10 @@ blobSizeInScalarB (${blobSizeInScalarB}) should divide another.`);
return `
const BlockBPerLoop = ${blockBPerLoop}u;
const LoadingPointAPerLoop = ${loadingPointAPerLoop}u;
const LoadingPointAPerSubgroupCache = ${loadingPointAPerSubgroupCache}u;
const KSplitFactor = ${kSplitFactor}u;
var<workgroup> inter_results: array<array<${output.type.value}, ${kSplitFactor}>, ${subgroupSize}>;
var subgroup_cached_loading_points_A: array<${a.type.value}, ${loadingPointAPerSubgroupThread}>;
var acc: ${output.type.value};
${shaderHelper.declareVariables(...inputVariables, output)}
Expand Down Expand Up @@ -737,32 +740,47 @@ fn main(
let input_b_col = output_col_workgroup_base + subgroup_id;
var subgroup_cached_loading_points_A: array<${a.type.value}, ${loadingPointAPerSubgroupThread}>;
// var acc: ${output.type.value};
// Blocks are splited in an interleaved manner for k-split.
var block: u32;
// block = BlockBPerLoop * k_split_id + loop_K * BlockBPerLoop * KSplitFactor + 0..(BlockBPerLoop-1)
var block: u32 = BlockBPerLoop * k_split_id;
var scale: ${scales.type.value};
// The default zero point is 8 for unsigned 4-bit quantization.
var zero_point = ${dataType}(${8.0});
${zeroPoints?
`var zero_point: ${dataType};`:
`// The default zero point is 8 for unsigned 4-bit quantization.
let zero_point = ${dataType}(${8.0});`}
var loading_point_b: ${b.type.value};
var b_dequantized_values: mat2x4<${dataType}>;
// Subgroup cache A position
var loading_point_A_subgroup_cache_start_col: u32;
// loading_point_A_subgroup_cache_start_col =
// LoadingPointAPerLoop * k_split_id +
// loop_K * LoadingPointAPerLoop * KSplitFactor +
// 0:loadingPointAPerSubgroupCache:LoadingPointAPerLoop-LoadingPointAPerSubgroupCache
var loading_point_A_subgroup_cache_start_col: u32 = LoadingPointAPerLoop * k_split_id;
for (var loop_K: u32 = 0; loop_K < loop_K_steps; loop_K++) {
let split_mapped_loop_K_id = (loop_K * KSplitFactor + k_split_id);
block = BlockBPerLoop * split_mapped_loop_K_id;
loading_point_A_subgroup_cache_start_col = LoadingPointAPerLoop * split_mapped_loop_K_id;
// let split_mapped_loop_K_id = (loop_K * KSplitFactor + k_split_id);
// block = BlockBPerLoop * split_mapped_loop_K_id;
// loading_point_A_subgroup_cache_start_col = LoadingPointAPerLoop * split_mapped_loop_K_id;
${Array.from({ length: loopSubsteps })
.map((_, subStep) => loopSubStep(subStep))
.join('\n')}
// block += BlockBPerLoop * KSplitFactor - (BlockBPerLoop - 1)
block += 1 + BlockBPerLoop * (KSplitFactor - 1);
// loading_point_A_subgroup_cache_start_col +=
// LoadingPointAPerLoop * KSplitFactor - (LoadingPointAPerLoop - loadingPointAPerSubgroupCache)
loading_point_A_subgroup_cache_start_col += LoadingPointAPerSubgroupCache + LoadingPointAPerLoop * (KSplitFactor - 1);
}
inter_results[subgroup_id][k_split_id] = acc;
// inter_results[subgroup_id][k_split_id] = acc;
workgroupBarrier();
if (local_idx < ${subgroupSize}) {
var output_value: ${output.type.value} = ${output.type.value}(0);
for (var b = 0u; b < ${kSplitFactor}; b++) {
output_value += inter_results[subgroup_id][b];
}
${Array.from({length: kSplitFactor}).map((_, kSplit) => `
output_value += inter_results[subgroup_id][${kSplit}];`).join('')}
if (output_col_workgroup_base + local_idx < uniforms.output_shape[2])
{
${output.setByIndices(`${output.type.indices}(batch, output_row_workgroup_base, output_col_workgroup_base + local_idx)`, 'output_value')}
Expand All @@ -771,7 +789,7 @@ fn main(
}`;
};
return {
name: 'BlockwiseMatMulNBits32',
name: 'BlockwiseMatMulNBitsSubgroups32N',
shaderCache: {
hint: `${attributes.blockSize};${aComponents};${bComponents};${subgroupSize};${kSplitFactor}`,
inputDependencies: Array(inputs.length).fill('rank'),
Expand All @@ -787,23 +805,32 @@ fn main(

export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => {
validateInputs(context.inputs, attributes);
// if (
// attributes.blockSize >= 32 &&
// context.adapterInfo.isVendor('intel') &&
// context.adapterInfo.isArchitecture('gen-12lp')
// ) {
// context.compute(
// createMatMulNBitsSubgroupsBlockSize32NProgramInfo(context.inputs, attributes, {
// subgroupSize: 16,
// kSplitFactor: 4,
// }),
// );
// } else
if (
attributes.blockSize >= 32 &&
attributes.blockSize === 32 &&
context.adapterInfo.isVendor('intel') &&
context.adapterInfo.isArchitecture('gen-12lp')
true
// context.adapterInfo.isArchitecture('gen-12lp')
) {
context.compute(
createMatMulNBitsSubgroupsBlockSize32NProgramInfo(context.inputs, attributes, {
subgroupSize: 16,
kSplitFactor: 16,
kSplitFactor: 8,
}),
);
} else if (
attributes.blockSize === 32 &&
context.adapterInfo.isVendor('intel') &&
context.adapterInfo.isArchitecture('gen-12lp')
) {
context.compute(createMatMulNBitsBlockSize32ProgramInfo(context.inputs, attributes));
// context.compute(createMatMulNBitsBlockSize32ProgramInfo(context.inputs, attributes));
// context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes));
} else {
context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes));
}
Expand Down

0 comments on commit d390058

Please sign in to comment.