Skip to content

Commit

Permalink
run format
Browse files (browse the repository at this point in the history)
  • Loading branch information
qjia7 committed Aug 15, 2024
1 parent a6707f2 commit 8d5c18e
Showing 1 changed file with 36 additions and 34 deletions.
70 changes: 36 additions & 34 deletions js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts
Original file line number | Diff line number | Diff line change
Expand Up @@ -326,8 +326,8 @@ export const createMatMulNBitsProgramInfo = (
};
};

// zeroPoints = null
export const createMatMulNBitsBlockwiseProgramInfo =
// zeroPoints = null
export const createMatMulNBitsBlockwiseProgramInfo =
(inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): ProgramInfo => {
const inputShape = inputs[0].dims;
const aRank = inputShape.length;
Expand Down Expand Up @@ -394,37 +394,39 @@ export const createMatMulNBitsProgramInfo = (
for (let r = 0; r < outputNumber; r++) {
calcStr += `
input_offset = ${a.indicesToOffset(`${a.type.indices}(batch, row + ${r}, word_offset)`)};
${Array.from({length: 8 / aComponents}).map((_, j) =>
`
let a_data_${r}_${j} = ${a.getByOffset(`input_offset+${j}`)};`
).join('\n')}`;
${
Array.from({length: 8 / aComponents})
.map((_, j) => `
let a_data_${r}_${j} = ${a.getByOffset(`input_offset+${j}`)};`)
.join('\n')}`;
}
for (let c = 0; c < components; c++) {
calcStr += `
b_value = ${bComponents === 1 ? `b${c}_data` : `b${c}_data[i]`};
b_value_lower = unpack4xU8(b_value & b_mask);
b_value_upper = unpack4xU8((b_value >> 4) & b_mask);
b_quantized_values = ${qDqDataType}(${
Array.from({length: 4}, (_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`)
.join(', ')});
Array.from({length: 4}, (_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`)
.join(', ')});
b_dequantized_values = ${(() => {
if (aComponents === 1) {
return `${qDqDataType}(${
Array.from({length: 8}, (_, i) => `(b_quantized_values[${i}] - zero_point) * scale${c}`).join(', ')});`;
} else {
return `(b_quantized_values - ${qDqDataType}(${Array(8).fill('zero_point').join(',')})) * scale${c};`;
}
})()};`;
if (aComponents === 1) {
return `${qDqDataType}(${
Array.from({length: 8}, (_, i) => `(b_quantized_values[${i}] - zero_point) * scale${c}`)
.join(', ')});`;
} else {
return `(b_quantized_values - ${qDqDataType}(${Array(8).fill('zero_point').join(',')})) * scale${c};`;
}
})()};`;
for (let r = 0; r < outputNumber; r++) {
calcStr += `
workgroup_shared[block * ${outputNumber} + ${r}]${components > 1 ? `[${c}]` : ''} += ${
Array
.from(
{length: 8 / aComponents},
(_, i) => `${
aComponents === 1 ? `a_data_${r}_${i} * b_dequantized_values[${i}]` :
`dot(a_data_${r}_${i}, b_dequantized_values[${i}])`}`)
.join(' + ')};
Array
.from(
{length: 8 / aComponents},
(_, i) => `${
aComponents === 1 ? `a_data_${r}_${i} * b_dequantized_values[${i}]` :
`dot(a_data_${r}_${i}, b_dequantized_values[${i}])`}`)
.join(' + ')};
`;
}
}
Expand All @@ -445,7 +447,7 @@ export const createMatMulNBitsProgramInfo = (
var b_value_lower: vec4<u32>;
var b_value_upper: vec4<u32>;
var b_quantized_values: ${qDqDataType};
var b_dequantized_values: ${qDqDataType};`
var b_dequantized_values: ${qDqDataType};`;
return calcStr;
};

Expand All @@ -454,24 +456,24 @@ export const createMatMulNBitsProgramInfo = (
let s_index0 = local_idx * ${outputNumber};
let s_index1 = (local_idx + interval) * ${outputNumber};`;
for (let r = 0; r < outputNumber; r++) {
str += `
str += `
workgroup_shared[s_index0 + ${r}] = workgroup_shared[s_index0 + ${r}] + workgroup_shared[s_index1 + ${r}];
`;
}
return str;
};
}
return str;
};

const setOutput = (): string => {
let str = `let globalRow = row * ${outputNumber};`;
for (let r = 0; r < outputNumber; r++) {
str += `
const setOutput = (): string => {
let str = `let globalRow = row * ${outputNumber};`;
for (let r = 0; r < outputNumber; r++) {
str += `
if (globalRow + ${r} < ${dimAOuter}) {
${output.setByIndices(`${output.type.indices}(batch, globalRow + ${r}, col)`, `workgroup_shared[${r}]`)};
}
`;
}
return str;
};
}
return str;
};

return `
var<workgroup> workgroup_shared: array<${output.type.value}, ${outputNumber * nBlocksPerCol}>;
Expand Down

0 comments on commit 8d5c18e

Please sign in to comment.