Skip to content

Commit

Permalink
Added block dequnatization.
Browse files Browse the repository at this point in the history
  • Loading branch information
satyajandhyala committed Aug 5, 2024
1 parent e0b5fdd commit 1e6c46f
Showing 1 changed file with 56 additions and 44 deletions.
100 changes: 56 additions & 44 deletions js/web/lib/wasm/jsep/webgpu/ops/quantize_linear.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,119 +69,131 @@ const createDequantizeLinearProgramInfo =
(inputs: readonly TensorView[], attributes: DequantizeLinerAttributes): ProgramInfo => {
const axis = ShapeUtil.normalizeAxis(attributes.axis, inputs[0].dims.length);
const inputType = inputs[0].dataType;
const isSigned = inputType === DataType.int8;
const outputShape = inputs[0].dims; // output shape is same as the input shape
const dataType = inputs[1].dataType; // output type is same as the the scale input type
const outputSize = ShapeUtil.size(outputShape);
const uniforms: UniformsArrayType = [{name: 'output_size', type: 'u32'}, {name: 'axis', type: 'u32'}];
const uniforms: UniformsArrayType =
[{name: 'output_size', type: 'u32'}, {name: 'axis', type: 'u32'}, {name: 'block_size', type: 'u32'}];
const isPacked = inputType === DataType.int8 || inputType === DataType.uint8;
const inputShape = isPacked ? ShapeUtil.convertShape(inputs[0].dims).slice() : inputs[0].dims;
const inputShape = isPacked ? [Math.ceil(ShapeUtil.size(inputs[0].dims) / 4)] : inputs[0].dims;
const scaleShape = inputs[1].dims;
const input = inputVariable('input', DataType.uint32, inputShape.length);
const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims.length);
const scale = inputVariable('scale', inputs[1].dataType, scaleShape.length);
const zeroPointInput = inputs.length > 2 ? inputs[2] : undefined;
const zeroPointShape = zeroPointInput ?
(isPacked ? [Math.ceil(ShapeUtil.size(zeroPointInput.dims) / 4)] : zeroPointInput.dims) :
undefined;
// Scales input is a scaler for per-tensor/per-layer quantization, 1-D tensor for per-axis quantization
// or tensor with same rank as input for blocked quantization.
const perLayerQuantization = scaleShape.length === 0 || (scaleShape.length === 1 && scaleShape[0] === 1);
const perAxisQuantization = perLayerQuantization === false && scaleShape.length === 1;
// Left unnecessary commented-out assignment for documentation
// const blockQuantization = perLayerQuantization === false && perAxisQuantization === false;
const zeroPoint =
inputs.length > 2 ? inputVariable('zero_point', DataType.uint32, inputs[2].dims.length) : undefined;
zeroPointInput ? inputVariable('zero_point', DataType.uint32, zeroPointShape!.length) : undefined;
const output = outputVariable('output', dataType, outputShape.length);
const inputVariables = [input, scale];
if (zeroPoint) {
inputVariables.push(zeroPoint);
}
const inputShapes = [inputShape, scaleShape];
if (zeroPointInput) {
inputShapes.push(zeroPointShape!);
}
const programUniforms: ProgramUniform[] = [
{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: axis},
...createTensorShapeVariables(...inputs.map((t) => t.dims), outputShape)
{type: DataType.uint32, data: attributes.blockSize}, ...createTensorShapeVariables(...inputShapes, outputShape)
];

const getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
let output_indices = ${output.offsetToIndices('global_idx')};
// Set input x
${(() => {
if (isPacked) {
const isSigned = inputType === DataType.uint8;
return `
let input = ${input.getByOffset('global_idx/4')};
let input = ${input.getByOffset('global_idx / 4')};
let x_vec: vec4<${isSigned ? 'i32' : 'u32'}> = ${isSigned ? 'unpack4xI8(input)' : 'unpack4xU8(input)'};
let x = x_vec[global_idx % 4];`;
let x_value = x_vec[global_idx % 4];`;
} else {
return `let x = ${input.getByOffset('global_idx')};`;
return `let x_value = ${input.getByOffset('global_idx')};`;
}
})()};
// Set scale input
${(() => {
const shape = inputs[1].dims;
if (shape.length === 0 || (shape.length === 1 && shape[0] === 1)) {
// scale input is a scalar
if (perLayerQuantization) {
// scale input is a scalar ()
return `
let scale = ${scale.getByOffset('0')}`;
} else if (shape.length === 1) {
let scale_value= ${scale.getByOffset('0')}`;
} else if (perAxisQuantization) {
// scale input is a 1D tensor
return `
let input_indices = ${input.offsetToIndices('global_idx')};
let input_index = ${input.indicesGet('input_indices', 'uniforms.axis')};
let scale_indices: ${scale.type.indices};
${scale.indicesSet('scale_indices', 'input_index', '0')};
scale = ${scale.getByOffset('scale_index')};`;
let scale_index = ${output.indicesGet('output_indices', 'uniforms.axis')};
let scale_value= ${scale.getByOffset('scale_index')};`;
} else {
// Block quantization. Scale input rank is same as input/output rank.
return `
scale = ${scale.getByOffset('global_idx')};`;
let scale_indices: ${scale.type.indices} = output_indices;
let index = ${scale.indicesGet('scale_indices', 'uniforms.axis')} / 'uniforms.block_size';
${scale.indicesSet('scale_indices', 'uniform.axis', 'index')};
let scale_value= ${scale.getByIndices('scales_indices')};`;
}
})()};
// Set zero-point input
${(() => {
if (zeroPoint) {
const isSigned = inputType === DataType.int8;
const shape = inputs[2].dims;
if (shape.length === 0 || (shape.length === 1 && shape[0] === 1)) {
if (perLayerQuantization) {
// zero-point input is a scalar
if (isPacked) {
return `
let zero_point_input = ${zeroPoint.getByOffset('0')};
let zero_point_vec: vec4<${isSigned ? 'i32' : 'u32'}> = ${
isSigned ? 'unpack4xI8(zero_point_input)' : 'unpack4xU8(zero_point_input)'};
let zero_point = zero_point_vec[0]`;
let zero_point_value= zero_point_vec[0]`;
} else {
return `let zero_point = ${zeroPoint.getByOffset('0')}`;
return `let zero_point_value = ${zeroPoint.getByOffset('0')}`;
}
// return `let zero_point = ${isSigned ? 'i32' : 'u32'}(${zeroPoint.getByOffset('0')});`;
} else if (shape.length === 1) {
} else if (perAxisQuantization) {
// zero-point input is a 1D tensor
if (isPacked) {
return `
let input_indices = ${input.offsetToIndices('global_idx')};
let input_index = ${input.indicesGet('input_indices', 'uniforms.axis')};
let zero_point_indices: ${zeroPoint.type.indices};
${zeroPoint.indicesSet('zero_point_indices', 'input_index', '0')};
let zero_point_input = ${zeroPoint.getByOffset('zero_point_index')};
let zero_point_index = ${output.indicesGet('output_indices', 'uniforms.axis')};
let zero_point_input = ${zeroPoint.getByOffset('zero_point_index / 4')};
let zero_point_vec: vec4<${isSigned ? 'i32' : 'u32'}> = ${
isSigned ? 'unpack4xI8(zero_point_input)' : 'unpack4xU8(zero_point_input)'};
let zero_point = zero_point_vec[global_idx % 4]`;
let zero_point_value = zero_point_vec[zero_point_index % 4]`;
} else {
return `
let input_indices = ${input.offsetToIndices('global_idx')};
let input_index = ${input.indicesGet('input_indices', 'uniforms.axis')};
let zero_point_indices: ${zeroPoint.type.indices};
${zeroPoint.indicesSet('zero_point_indices', 'input_index', '0')};
let zero_point = ${zeroPoint.getByOffset('zero_point_index')};`;
let zero_point_index = ${output.indicesGet('output_indices', 'uniforms.axis')};
let zero_point_value = ${zeroPoint.getByOffset('zero_point_index')};`;
}
} else {
// blocked quantization
// BlockedQuantization. The zero-point input shape is same as the input shape except along axis.
if (isPacked) {
return `
let zero_point_input = ${input.getByOffset('global_idx/4')};
let zero_point_offset = ${scale.indicesToOffset('scale_indices')};
let zero_point_input = ${zeroPoint.getByOffset('zero_point_offset / 4')};
let zero_point_vec: vec4<${isSigned ? 'i32' : 'u32'} = ${
isSigned ? 'unpack4xI8(zero_point_input)' : 'unpack4xU8(zero_point_input)'};
let zero_point = zero_point_vec[global_idx % 4];`;
let zero_point_value = zero_point_vec[zero_point_offset % 4];`;
} else {
return `let zero_point = ${isSigned ? 'i32' : 'u32'}(${zeroPoint.getByOffset('0')});`;
return `let zero_point_value = ${zeroPoint.getByIndices('scale_indices')};`;
}
}
} else {
return 'let zero_point = 0;';
return `let zero_point_value: ${tensorTypeToWsglStorageType(inputType)} = 0;`;
}
})()};
// Compute and write output
${output.setByOffset('global_idx', `${tensorTypeToWsglStorageType(dataType)}(x - zero_point) * scale`)};
${
output.setByOffset(
'global_idx', `${tensorTypeToWsglStorageType(dataType)}(x_value - zero_point_value) * scale_value`)};
}`;
return {
name: 'DequantizeLinear',
Expand Down

0 comments on commit 1e6c46f

Please sign in to comment.