Skip to content

Commit

Permalink
[torchlib] Implement quantize/dequantize operators
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Jul 17, 2024
1 parent f8ee736 commit 31513ef
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 1 deletion.
60 changes: 60 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
# pylint: disable=unused-argument
"""quantized_decomposed ops defined in https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py
- No inplace operators.
- All functions should not have the script() decorator. This is because
we want to delay the compilation of the function.
"""

from __future__ import annotations

from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType


@torch_op(
    (
        "quantized_decomposed::quantize_per_tensor",
        "quantized_decomposed::quantize_per_tensor.tensor",
        "quantized_decomposed::quantize_per_tensor.tensor2",
    ),
    trace_only=True,
)
def quantized_decomposed_quantize_per_tensor(
    input: TensorType,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: int,
) -> TensorType:
    """Affine-quantize ``input`` with a single per-tensor scale and zero point.

    ``quant_min``, ``quant_max`` and ``dtype`` are accepted to mirror the
    PyTorch decomposed op's signature but are not forwarded to ONNX yet.
    """
    # TODO(justinchuby): Use quant_min and quant_max
    # TODO(justinchuby): Use dtype when we use opset 21
    quantized = op.QuantizeLinear(input, scale, zero_point)
    return quantized

Check warning on line 37 in onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py#L37

Added line #L37 was not covered by tests


@torch_op(
    (
        "quantized_decomposed::dequantize_per_tensor",
        "quantized_decomposed::dequantize_per_tensor.tensor",
        "quantized_decomposed::dequantize_per_tensor.tensor2",
    ),
    trace_only=True,
)
def quantized_decomposed_dequantize_per_tensor(
    input: TensorType,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: int,
    *,
    out_dtype: int | None = None,
) -> TensorType:
    """Dequantize ``input`` using a single per-tensor scale and zero point.

    ``quant_min``, ``quant_max``, ``dtype`` and the keyword-only ``out_dtype``
    match the PyTorch decomposed op's signature; they are currently unused
    when lowering to ONNX.
    """
    # TODO(justinchuby): Use quant_min and quant_max
    # TODO(justinchuby): Use dtype when we use opset 21
    dequantized = op.DequantizeLinear(input, scale, zero_point)
    return dequantized

Check warning on line 60 in onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py#L60

Added line #L60 was not covered by tests
2 changes: 1 addition & 1 deletion onnxscript/function_libs/torch_lib/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def torch_op(
private: bool = False,
complex: bool = False,
traceable: bool = False,
) -> Callable[[FunctionType], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]:
) -> Callable[[Callable], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]:
"""Register a torch op.
Args:
Expand Down

0 comments on commit 31513ef

Please sign in to comment.