diff --git a/js/web/lib/wasm/jsep/webgpu/ops/quantize_linear.ts b/js/web/lib/wasm/jsep/webgpu/ops/quantize_linear.ts
index 105a6db7d731f..ed58bddf874d9 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/quantize_linear.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/quantize_linear.ts
@@ -14,21 +14,54 @@ export interface DequantizeLinerAttributes extends AttributeWithCacheKey {
   blockSize: number;
 }
 
-const validateInputs = (inputs: readonly TensorView[]): void => {
+const validateInputs = (inputs: readonly TensorView[], attributes: DequantizeLinerAttributes): void => {
   if (inputs.length < 2 || inputs.length > 3) {
     throw new Error('DequantizeLinear requires 2 or 3 inputs.');
   }
   if (inputs.length === 3 && inputs[1].dims === inputs[2].dims) {
-    throw new Error('x-scale and x-zero-point must have the same shape');
+    throw new Error('x-scale and x-zero-point must have the same shape.');
   }
   if (inputs.length === 3 && inputs[0].dataType !== inputs[2].dataType) {
-    throw new Error('x and x-zero-point must have the same data type');
+    throw new Error('x and x-zero-point must have the same data type.');
   }
   if (inputs[0].dataType === DataType.int32 && inputs.length > 2) {
     throw new Error('In the case of dequantizing int32 there is no zero point.');
   }
   if (inputs[1].dims.length !== 0 && inputs[1].dims.length !== 1 && inputs[1].dims.length !== inputs[0].dims.length) {
-    throw new Error('scale input must be a scalar, a 1D tensor, or have the same rank as the input tensor');
+    throw new Error('scale input must be a scalar, a 1D tensor, or have the same rank as the input tensor.');
+  }
+  // Validate scale and zero-point input shapes.
+  if (inputs.length > 2) {
+    // The zero-point input type must be the same as the input data type.
+    if (inputs[0].dataType !== inputs[2].dataType) {
+      throw new Error('x and x-zero-point must have the same data type.');
+    }
+    // The scale and zero-point inputs must have the same shape.
+    if (inputs[1].dims.length !== inputs[2].dims.length) {
+      throw new Error('scale and zero-point inputs must have the same rank.');
+    }
+    if (!inputs[1].dims.map((d, i) => d === inputs[2].dims[i]).reduce((a, b) => a && b, true)) {
+      throw new Error('scale and zero-point inputs must have the same shape.');
+    }
+  }
+  // Validate blockSize.
+  if (attributes.blockSize > 0) {
+    // Block quantization.
+    if (inputs[1].dims.length === 0 || (inputs[1].dims.length === 1 && inputs[1].dims[0] === 1)) {
+      throw new Error('blockSize must be set only for block quantization.');
+    }
+    if (!inputs[1].dims.map((d, i) => i === attributes.axis || d === inputs[0].dims[i]).reduce((a, b) => a && b, true)) {
+      throw new Error('For block quantization, the scale input shape must match the input shape except for the axis.');
+    }
+    // The scale input rank must be the same as the input rank.
+    if (inputs[1].dims.length !== inputs[0].dims.length) {
+      throw new Error('For block quantization the scale input rank must be the same as the x rank.');
+    }
+    const dI = inputs[0].dims[attributes.axis];
+    const si = inputs[1].dims[attributes.axis];
+    if (attributes.blockSize < Math.ceil(dI / si) || attributes.blockSize > Math.ceil(dI / (si - 1) - 1)) {
+      throw new Error('blockSize must be within the range [ceil(dI / Si), ceil(dI / (Si - 1) - 1)].');
+    }
   }
 };
 
@@ -164,7 +197,7 @@ const createDequantizeLinearProgramInfo =
     };
 
 export const dequantizeLinear = (context: ComputeContext, attributes: DequantizeLinerAttributes): void => {
-  validateInputs(context.inputs);
+  validateInputs(context.inputs, attributes);
   context.compute(createDequantizeLinearProgramInfo(context.inputs, attributes));
 };
 
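
Reviewer note (illustration only, not part of the patch): the final added check enforces the range stated in its error message, i.e. blockSize must lie in [ceil(dI / Si), ceil(dI / (Si - 1) - 1)], where dI is the x dimension on the quantization axis and Si is the scale dimension on that axis (the number of blocks). A minimal TypeScript sketch of that same bound with concrete numbers follows; isValidBlockSize is a hypothetical helper written only for this note, not an export of quantize_linear.ts:

// Hypothetical helper, for illustration only: the same bound used in validateInputs above.
const isValidBlockSize = (dI: number, si: number, blockSize: number): boolean =>
  blockSize >= Math.ceil(dI / si) && blockSize <= Math.ceil(dI / (si - 1) - 1);

// With dI = 8 elements on the axis split into si = 2 blocks:
// lower bound ceil(8 / 2) = 4, upper bound ceil(8 / 1 - 1) = 7.
console.log(isValidBlockSize(8, 2, 4)); // true
console.log(isValidBlockSize(8, 2, 7)); // true
console.log(isValidBlockSize(8, 2, 8)); // false: 8 elements per block would leave only one block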