Skip to content

Commit

Permalink
Added more input validation.
Browse files Browse the repository at this point in the history
  • Loading branch information
satyajandhyala committed Aug 5, 2024
1 parent e293a08 commit e0b5fdd
Showing 1 changed file with 38 additions and 5 deletions.
43 changes: 38 additions & 5 deletions js/web/lib/wasm/jsep/webgpu/ops/quantize_linear.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,54 @@ export interface DequantizeLinerAttributes extends AttributeWithCacheKey {
blockSize: number;
}

const validateInputs = (inputs: readonly TensorView[]): void => {
const validateInputs = (inputs: readonly TensorView[], attributes: DequantizeLinerAttributes): void => {
if (inputs.length < 2 || inputs.length > 3) {
throw new Error('DequantizeLinear requires 2 or 3 inputs.');
}
if (inputs.length === 3 && inputs[1].dims === inputs[2].dims) {
throw new Error('x-scale and x-zero-point must have the same shape');
throw new Error('x-scale and x-zero-point must have the same shape.');
}
if (inputs.length === 3 && inputs[0].dataType !== inputs[2].dataType) {
throw new Error('x and x-zero-point must have the same data type');
throw new Error('x and x-zero-point must have the same data type.');
}
if (inputs[0].dataType === DataType.int32 && inputs.length > 2) {
throw new Error('In the case of dequantizing int32 there is no zero point.');
}
if (inputs[1].dims.length !== 0 && inputs[1].dims.length !== 1 && inputs[1].dims.length !== inputs[0].dims.length) {
throw new Error('scale input must be a scalar, a 1D tensor, or have the same rank as the input tensor');
throw new Error('scale input must be a scalar, a 1D tensor, or have the same rank as the input tensor.');
}
// validate scale and zero-point input shapes
if (inputs.length > 2) {
// zero-point input type should be the same as input data type.
if (inputs[0].dataType !== inputs[2].dataType) {
throw new Error('x and x-zero-point must have the same data type.');
}
// Scale and zero-point inputs must have the same shape
if (inputs[1].dims.length !== inputs[2].dims.length) {
throw new Error('scale and zero-point inputs must have the same rank.');
}
if (!inputs[1].dims.map((d, i) => d === inputs[2].dims[i]).reduce((a, b) => a && b, true)) {
throw new Error('scale and zero-point inputs must have the same shape.');
}
}
// Validate blockSize
if (attributes.blockSize > 0) {
// Block qunatization
if (inputs[1].dims.length === 0 || (inputs[1].dims.length === 1 && inputs[1].dims[0] === 1)) {
throw new Error('blockSize must be set only for block quantization.');
}
if (!inputs[1].dims.map((d, i) => i === attributes.axis || d === inputs[0].dims[i]).reduce((a, b) => a && b, true)) {
throw new Error('For block qunatization, scale input shape to match the input shape except for the axis')
}
// Scale input rank should be same as the input rank
if (inputs[1].dims.length != inputs[0].dims.length) {
throw new Error('For block qunatization the scale input rank must be the same as the x rank.');
}
const dI = inputs[0].dims[attributes.axis];
const si = inputs[1].dims[attributes.axis];
if (attributes.blockSize < Math.ceil(dI / si) || attributes.blockSize > Math.ceil(dI / (si - 1) - 1)) {
throw new Error('blockSize must be with in the range [ceil(dI / Si), ceil(dI / (Si - 1) - 1)].');
}
}
};

Expand Down Expand Up @@ -164,7 +197,7 @@ const createDequantizeLinearProgramInfo =
};

export const dequantizeLinear = (context: ComputeContext, attributes: DequantizeLinerAttributes): void => {
validateInputs(context.inputs);
validateInputs(context.inputs, attributes);
context.compute(createDequantizeLinearProgramInfo(context.inputs, attributes));
};

Expand Down

0 comments on commit e0b5fdd

Please sign in to comment.