From 31513efb8c4b731234871969f019574487dba5ba Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 17 Jul 2024 01:21:20 +0000
Subject: [PATCH] [torchlib] Implement quantize/dequantize operators

---
 .../torch_lib/ops/quantized_decomposed.py     | 95 +++++++++++++++++++
 .../function_libs/torch_lib/registration.py   |  2 +-
 2 files changed, 96 insertions(+), 1 deletion(-)
 create mode 100644 onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py

diff --git a/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py b/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py
new file mode 100644
index 000000000..30df24e09
--- /dev/null
+++ b/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py
@@ -0,0 +1,95 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
+# pylint: disable=unused-argument
+"""quantized_decomposed ops defined in https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py
+
+- No inplace operators.
+- No function should have the script() decorator. This is because
+    we want to delay the compilation of the function.
+"""
+
+from __future__ import annotations
+
+from onnxscript.function_libs.torch_lib.registration import torch_op
+from onnxscript.onnx_opset import opset18 as op
+from onnxscript.onnx_types import TensorType
+
+
+@torch_op(
+    (
+        "quantized_decomposed::quantize_per_tensor",
+        "quantized_decomposed::quantize_per_tensor.tensor",
+        "quantized_decomposed::quantize_per_tensor.tensor2",
+    ),
+    trace_only=True,
+)
+def quantized_decomposed_quantize_per_tensor(
+    input: TensorType,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: int,
+) -> TensorType:
+    """Quantize ``input`` per tensor via ``QuantizeLinear``.
+
+    Args:
+        input: Tensor to quantize.
+        scale: Quantization scale.
+        zero_point: Quantization zero point; cast to ``dtype`` before use.
+        quant_min: Currently unused.
+        quant_max: Currently unused.
+        dtype: ONNX data type the zero point is cast to.
+
+    Returns:
+        The output of ``QuantizeLinear``.
+    """
+    # TODO(justinchuby): Use quant_min and quant_max
+    # TODO(justinchuby): Use dtype when we use opset 21
+    # QuantizeLinear takes its output type from the zero point, so cast the
+    # zero point to the requested dtype instead of passing it through as-is.
+    return op.QuantizeLinear(input, scale, op.Cast(zero_point, to=dtype))
+
+
+@torch_op(
+    (
+        "quantized_decomposed::dequantize_per_tensor",
+        "quantized_decomposed::dequantize_per_tensor.tensor",
+        "quantized_decomposed::dequantize_per_tensor.tensor2",
+    ),
+    trace_only=True,
+)
+def quantized_decomposed_dequantize_per_tensor(
+    input: TensorType,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: int,
+    *,
+    out_dtype: int | None = None,
+) -> TensorType:
+    """Dequantize ``input`` per tensor via ``DequantizeLinear``.
+
+    Args:
+        input: Quantized tensor.
+        scale: Quantization scale.
+        zero_point: Quantization zero point; cast to ``dtype`` before use.
+        quant_min: Currently unused.
+        quant_max: Currently unused.
+        dtype: ONNX data type the zero point is cast to.
+        out_dtype: ONNX data type to cast the result to. ``None`` or ``-1``
+            keeps the type produced by ``DequantizeLinear``.
+
+    Returns:
+        The dequantized tensor, cast to ``out_dtype`` when one is given.
+    """
+    # TODO(justinchuby): Use quant_min and quant_max
+    # TODO(justinchuby): Use dtype when we use opset 21
+    # DequantizeLinear requires the zero point to have the same type as the
+    # quantized input, so cast it to dtype.
+    dequantized = op.DequantizeLinear(input, scale, op.Cast(zero_point, to=dtype))
+    if out_dtype is None or out_dtype == -1:
+        return dequantized
+    return op.Cast(dequantized, to=out_dtype)
diff --git a/onnxscript/function_libs/torch_lib/registration.py b/onnxscript/function_libs/torch_lib/registration.py
index 505edee06..6e706834e 100644
--- a/onnxscript/function_libs/torch_lib/registration.py
+++ b/onnxscript/function_libs/torch_lib/registration.py
@@ -102,7 +102,7 @@ def torch_op(
     private: bool = False,
     complex: bool = False,
     traceable: bool = False,
-) -> Callable[[FunctionType], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]:
+) -> Callable[[Callable], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]:
     """Register a torch op.
 
     Args: