Skip to content

Commit

Permalink
Indices should be normalized before indexing. Added a test case.
Browse files Browse the repository at this point in the history
  • Loading branch information
satyajandhyala committed Aug 17, 2024
1 parent 0a7387d commit 9b5eac4
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 21 deletions.
45 changes: 24 additions & 21 deletions js/web/lib/wasm/jsep/webgpu/ops/gather-block-quantized.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,13 @@ const createGatherBlockQuantizedProgramInfo = (
return `
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
${shaderHelper.mainStart()}
var output_indices = ${output.offsetToIndices('global_idx')};
let output_indices = ${output.offsetToIndices('global_idx')};
var indices_indices = ${indices.type.indices}(0);
${(() => {
if (indicesShape.length > 1) {
return `
for (var i: u32 = 0; i < ${indicesShape.length}; i++) {
var index = ${output.indicesGet('output_indices', 'uniforms.gather_axis + i')};
let index = ${output.indicesGet('output_indices', 'uniforms.gather_axis + i')};
${indices.indicesSet('indices_indices', 'i', 'index')};
}`;
} else {
Expand All @@ -121,41 +121,44 @@ const createGatherBlockQuantizedProgramInfo = (
})()};
var data_indices = ${data.type.indices}(0);
for (var i: u32 = 0; i < uniforms.gather_axis; i++) {
var index = ${output.indicesGet('output_indices', 'i')};
let index = ${output.indicesGet('output_indices', 'i')};
${data.indicesSet('data_indices', 'i', 'index')};
}
var index_from_indices = u32(${indices.getByIndices('indices_indices')});
${data.indicesSet('data_indices', 'uniforms.gather_axis', 'index_from_indices')};
var index_from_indices = ${indices.getByIndices('indices_indices')};
if (index_from_indices < 0) {
index_from_indices += ${inputShape[gatherAxis]};
}
${data.indicesSet('data_indices', 'uniforms.gather_axis', 'u32(index_from_indices)')};
for (var i = uniforms.gather_axis + 1; i < ${outputShape.length}; i++) {
var index = ${output.indicesGet('output_indices', `i + ${indicesShape.length} - 1`)};
let index = ${output.indicesGet('output_indices', `i + ${indicesShape.length} - 1`)};
${data.indicesSet('data_indices', 'i', 'index')};
}
var data_offset = ${data.indicesToOffset('data_indices')};
var data_index = data_offset % 8;
let data_offset = ${data.indicesToOffset('data_indices')};
let data_index = data_offset % 8;
// Convert 4-bit packed data to 8-bit packed data.
var packed_4bit_quantized_data = ${data.getByOffset('data_offset / 8')};
var packed_8bit_quantized_data = (packed_4bit_quantized_data >> (4 * (data_index % 2))) & 0x0f0f0f0f;
var quantized_data_vec = ${isSigned ? 'unpack4xI8' : 'unpack4xU8'}(u32(packed_8bit_quantized_data));
var quantized_data = quantized_data_vec[data_index / 2];
let packed_4bit_quantized_data = ${data.getByOffset('data_offset / 8')};
let packed_8bit_quantized_data = (packed_4bit_quantized_data >> (4 * (data_index % 2))) & 0x0f0f0f0f;
let quantized_data_vec = ${isSigned ? 'unpack4xI8' : 'unpack4xU8'}(u32(packed_8bit_quantized_data));
let quantized_data = quantized_data_vec[data_index / 2];
var scale_indices = data_indices;
var quantize_axis_index = ${scales.indicesGet('data_indices', 'uniforms.quantize_axis')} / uniforms.block_size;
let quantize_axis_index = ${scales.indicesGet('data_indices', 'uniforms.quantize_axis')} / uniforms.block_size;
${scales.indicesSet('scale_indices', 'uniforms.quantize_axis', 'quantize_axis_index')};
var scale = ${scales.getByIndices('scale_indices')};
${(() => {
if (!zeroPoint) {
return 'var zero_point = 0';
} else {
return `
var zero_point_indices = scale_indices;
var zero_point_offset = ${zeroPoint.indicesToOffset('zero_point_indices')};
var zero_point_index = zero_point_offset % 8;
var packed_4bit_zero_points = ${zeroPoint.getByOffset('zero_point_offset / 8')};
var packed_8bit_zero_point = (packed_4bit_zero_points >> (4 * (zero_point_index % 2))) & 0x0f0f0f0f;
var zero_point_vec = ${isSigned ? 'unpack4xI8' : 'unpack4xU8'}(u32(packed_8bit_zero_point));
var zero_point = zero_point_vec[zero_point_index / 2];`;
let zero_point_indices = scale_indices;
let zero_point_offset = ${zeroPoint.indicesToOffset('zero_point_indices')};
let zero_point_index = zero_point_offset % 8;
let packed_4bit_zero_points = ${zeroPoint.getByOffset('zero_point_offset / 8')};
let packed_8bit_zero_points = (packed_4bit_zero_points >> (4 * (zero_point_index % 2))) & 0x0f0f0f0f;
let zero_point_vec = ${isSigned ? 'unpack4xI8' : 'unpack4xU8'}(u32(packed_8bit_zero_points));
let zero_point = zero_point_vec[zero_point_index / 2];`;
}
})()};
var dequantized_data = ${tensorTypeToWsglValueType(outputType)}(quantized_data - zero_point) * scale;
let dequantized_data = ${tensorTypeToWsglValueType(outputType)}(quantized_data - zero_point) * scale;
${output.setByOffset('global_idx', 'dequantized_data')};
}`;
};
Expand Down
42 changes: 42 additions & 0 deletions js/web/test/data/ops/gather-block-quantized.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,48 @@
"type": "float32"
}
]
},
{
"name": "GatherBlockQuantized; quantize_axis=2, gather_axis=0, block_size=16, signed input, negative indices",
"inputs": [
// data
{
"data": [
0, 1, 2, 3, 4, 5, 6, 7, 7, 0, 1, 2, 3, 4, 5, 6, 6, 7, 0, 1, 2, 3, 4, 5, 5, 6, 7, 0, 1, 2, 3, 4, 4, 5, 6,
7, 0, 1, 2, 3, 3, 4, 5, 6, 7, 0, 1, 2, 2, 3, 4, 5, 6, 7, 0, 1, 1, 2, 3, 4, 5, 6, 7, 0, 0, 1, 2, 3, 4, 5,
6, 7, 7, 0, 1, 2, 3, 4, 5, 6, 6, 7, 0, 1, 2, 3, 4, 5, 5, 6, 7, 0, 1, 2, 3, 4
],
"dims": [2, 3, 16],
"type": "int4"
},
// indices
{
"data": [-1],
"dims": [1],
"type": "int32"
},
// scale
{
"data": [0.5, 1.0, 1.25, 1.5, 1.75, 2.0],
"dims": [2, 3, 1],
"type": "float32"
},
// zero point
{
"data": [0, 1, 2, 3, 4, 5],
"dims": [2, 3, 1],
"type": "int4"
}
],
"outputs": [
{
"data": [
-1.5,0,1.5,3,4.5,6,-4.5,-3,-3,-1.5,0,1.5,3,4.5,6,-4.5,-7,-5.25,-3.5,-1.75,0,1.75,3.5,5.25,5.25,-7,-5.25,-3.5,-1.75,0,1.75,3.5,2,4,-10,-8,-6,-4,-2,0,0,2,4,-10,-8,-6,-4,-2
],
"dims": [1, 3, 16],
"type": "float32"
}
]
}
]
},
Expand Down

0 comments on commit 9b5eac4

Please sign in to comment.