diff --git a/js/web/lib/wasm/jsep/webgpu/ops/quantize_linear.ts b/js/web/lib/wasm/jsep/webgpu/ops/quantize_linear.ts index ed58bddf874d9..ae7b71c096763 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/quantize_linear.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/quantize_linear.ts @@ -69,119 +69,131 @@ const createDequantizeLinearProgramInfo = (inputs: readonly TensorView[], attributes: DequantizeLinerAttributes): ProgramInfo => { const axis = ShapeUtil.normalizeAxis(attributes.axis, inputs[0].dims.length); const inputType = inputs[0].dataType; + const isSigned = inputType === DataType.int8; const outputShape = inputs[0].dims; // output shape is same as the input shape - const dataType = inputs[1].dataType; // output type is same as the the scale input type + const dataType = inputs[1].dataType; // output type is same as the scale input type const outputSize = ShapeUtil.size(outputShape); - const uniforms: UniformsArrayType = [{name: 'output_size', type: 'u32'}, {name: 'axis', type: 'u32'}]; + const uniforms: UniformsArrayType = + [{name: 'output_size', type: 'u32'}, {name: 'axis', type: 'u32'}, {name: 'block_size', type: 'u32'}]; const isPacked = inputType === DataType.int8 || inputType === DataType.uint8; - const inputShape = isPacked ? ShapeUtil.convertShape(inputs[0].dims).slice() : inputs[0].dims; + const inputShape = isPacked ? [Math.ceil(ShapeUtil.size(inputs[0].dims) / 4)] : inputs[0].dims; + const scaleShape = inputs[1].dims; const input = inputVariable('input', DataType.uint32, inputShape.length); - const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims.length); + const scale = inputVariable('scale', inputs[1].dataType, scaleShape.length); + const zeroPointInput = inputs.length > 2 ? inputs[2] : undefined; + const zeroPointShape = zeroPointInput ? + (isPacked ? [Math.ceil(ShapeUtil.size(zeroPointInput.dims) / 4)] : zeroPointInput.dims) : + undefined; + // Scales input is a scalar for per-tensor/per-layer quantization, 1-D tensor for per-axis quantization + // or tensor with same rank as input for blocked quantization. 
+ const perLayerQuantization = scaleShape.length === 0 || (scaleShape.length === 1 && scaleShape[0] === 1); + const perAxisQuantization = perLayerQuantization === false && scaleShape.length === 1; + // Left unnecessary commented-out assignment for documentation + // const blockQuantization = perLayerQuantization === false && perAxisQuantization === false; const zeroPoint = - inputs.length > 2 ? inputVariable('zero_point', DataType.uint32, inputs[2].dims.length) : undefined; + zeroPointInput ? inputVariable('zero_point', DataType.uint32, zeroPointShape!.length) : undefined; const output = outputVariable('output', dataType, outputShape.length); const inputVariables = [input, scale]; if (zeroPoint) { inputVariables.push(zeroPoint); } + const inputShapes = [inputShape, scaleShape]; + if (zeroPointInput) { + inputShapes.push(zeroPointShape!); + } const programUniforms: ProgramUniform[] = [ {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: axis}, - ...createTensorShapeVariables(...inputs.map((t) => t.dims), outputShape) + {type: DataType.uint32, data: attributes.blockSize}, ...createTensorShapeVariables(...inputShapes, outputShape) ]; const getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + let output_indices = ${output.offsetToIndices('global_idx')}; + // Set input x ${(() => { if (isPacked) { - const isSigned = inputType === DataType.uint8; return ` - let input = ${input.getByOffset('global_idx/4')}; + let input = ${input.getByOffset('global_idx / 4')}; let x_vec: vec4<${isSigned ? 'i32' : 'u32'}> = ${isSigned ? 
'unpack4xI8(input)' : 'unpack4xU8(input)'}; - let x = x_vec[global_idx % 4];`; + let x_value = x_vec[global_idx % 4];`; } else { - return `let x = ${input.getByOffset('global_idx')};`; + return `let x_value = ${input.getByOffset('global_idx')};`; } })()}; // Set scale input ${(() => { - const shape = inputs[1].dims; - if (shape.length === 0 || (shape.length === 1 && shape[0] === 1)) { - // scale input is a scalar + if (perLayerQuantization) { + // scale input is a scalar return ` - let scale = ${scale.getByOffset('0')}`; - } else if (shape.length === 1) { + let scale_value= ${scale.getByOffset('0')}`; + } else if (perAxisQuantization) { // scale input is a 1D tensor return ` - let input_indices = ${input.offsetToIndices('global_idx')}; - let input_index = ${input.indicesGet('input_indices', 'uniforms.axis')}; - let scale_indices: ${scale.type.indices}; - ${scale.indicesSet('scale_indices', 'input_index', '0')}; - scale = ${scale.getByOffset('scale_index')};`; + let scale_index = ${output.indicesGet('output_indices', 'uniforms.axis')}; + let scale_value= ${scale.getByOffset('scale_index')};`; } else { + // Block quantization. Scale input rank is same as input/output rank; scale_indices is declared with var because indicesSet modifies it. return ` - scale = ${scale.getByOffset('global_idx')};`; + var scale_indices: ${scale.type.indices} = output_indices; + let index = ${scale.indicesGet('scale_indices', 'uniforms.axis')} / uniforms.block_size; + ${scale.indicesSet('scale_indices', 'uniforms.axis', 'index')}; + let scale_value= ${scale.getByIndices('scale_indices')};`; } })()}; // Set zero-point input ${(() => { if (zeroPoint) { - const isSigned = inputType === DataType.int8; - const shape = inputs[2].dims; - if (shape.length === 0 || (shape.length === 1 && shape[0] === 1)) { + if (perLayerQuantization) { // zero-point input is a scalar if (isPacked) { return ` let zero_point_input = ${zeroPoint.getByOffset('0')}; let zero_point_vec: vec4<${isSigned ? 'i32' : 'u32'}> = ${ isSigned ? 
'unpack4xI8(zero_point_input)' : 'unpack4xU8(zero_point_input)'}; - let zero_point = zero_point_vec[0]`; + let zero_point_value= zero_point_vec[0]`; } else { - return `let zero_point = ${zeroPoint.getByOffset('0')}`; + return `let zero_point_value = ${zeroPoint.getByOffset('0')}`; } - // return `let zero_point = ${isSigned ? 'i32' : 'u32'}(${zeroPoint.getByOffset('0')});`; - } else if (shape.length === 1) { + } else if (perAxisQuantization) { // zero-point input is a 1D tensor if (isPacked) { return ` - let input_indices = ${input.offsetToIndices('global_idx')}; - let input_index = ${input.indicesGet('input_indices', 'uniforms.axis')}; - let zero_point_indices: ${zeroPoint.type.indices}; - ${zeroPoint.indicesSet('zero_point_indices', 'input_index', '0')}; - let zero_point_input = ${zeroPoint.getByOffset('zero_point_index')}; + let zero_point_index = ${output.indicesGet('output_indices', 'uniforms.axis')}; + let zero_point_input = ${zeroPoint.getByOffset('zero_point_index / 4')}; let zero_point_vec: vec4<${isSigned ? 'i32' : 'u32'}> = ${ isSigned ? 'unpack4xI8(zero_point_input)' : 'unpack4xU8(zero_point_input)'}; - let zero_point = zero_point_vec[global_idx % 4]`; + let zero_point_value = zero_point_vec[zero_point_index % 4]`; } else { return ` - let input_indices = ${input.offsetToIndices('global_idx')}; - let input_index = ${input.indicesGet('input_indices', 'uniforms.axis')}; - let zero_point_indices: ${zeroPoint.type.indices}; - ${zeroPoint.indicesSet('zero_point_indices', 'input_index', '0')}; - let zero_point = ${zeroPoint.getByOffset('zero_point_index')};`; + let zero_point_index = ${output.indicesGet('output_indices', 'uniforms.axis')}; + let zero_point_value = ${zeroPoint.getByOffset('zero_point_index')};`; } } else { - // blocked quantization + // BlockedQuantization. The zero-point input shape is same as the input shape except along axis. 
if (isPacked) { return ` - let zero_point_input = ${input.getByOffset('global_idx/4')}; + let zero_point_offset = ${scale.indicesToOffset('scale_indices')}; + let zero_point_input = ${zeroPoint.getByOffset('zero_point_offset / 4')}; - let zero_point_vec: vec4<${isSigned ? 'i32' : 'u32'} = ${ + let zero_point_vec: vec4<${isSigned ? 'i32' : 'u32'}> = ${ isSigned ? 'unpack4xI8(zero_point_input)' : 'unpack4xU8(zero_point_input)'}; - let zero_point = zero_point_vec[global_idx % 4];`; + let zero_point_value = zero_point_vec[zero_point_offset % 4];`; } else { - return `let zero_point = ${isSigned ? 'i32' : 'u32'}(${zeroPoint.getByOffset('0')});`; + return `let zero_point_value = ${zeroPoint.getByIndices('scale_indices')};`; } } } else { - return 'let zero_point = 0;'; + return `let zero_point_value: ${isPacked ? (isSigned ? 'i32' : 'u32') : tensorTypeToWsglStorageType(inputType)} = 0;`; } })()}; // Compute and write output - ${output.setByOffset('global_idx', `${tensorTypeToWsglStorageType(dataType)}(x - zero_point) * scale`)}; + ${ + output.setByOffset( + 'global_idx', `${tensorTypeToWsglStorageType(dataType)}(x_value - zero_point_value) * scale_value`)}; }`; return { name: 'DequantizeLinear',