diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py
index 5e0a48077..54aa412ff 100644
--- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py
+++ b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py
@@ -19,7 +19,7 @@
 from typing_extensions import TypeAlias
 
 import onnxscript
-from onnxscript import evaluator
+from onnxscript import evaluator, ir
 from onnxscript import tensor as onnxscript_tensor
 from onnxscript._internal import param_manipulation, runtime_typing
 from onnxscript.function_libs.torch_lib import _flags
@@ -440,6 +440,8 @@ def _add_attribute_to_torchscript_node(
         return node.s_(key, value)  # type: ignore[arg-type]
     if isinstance(value, torch.Tensor):
         return node.t_(key, value)
+    if isinstance(value, ir.TensorProtocol):
+        return node.t_(key, torch.from_dlpack(value))
     if isinstance(value, Sequence):
         if not value:
             # Treat empty sequences as empty list tensors
diff --git a/onnxscript/function_libs/torch_lib/ops/common.py b/onnxscript/function_libs/torch_lib/ops/common.py
index cae319e2e..d7784a528 100644
--- a/onnxscript/function_libs/torch_lib/ops/common.py
+++ b/onnxscript/function_libs/torch_lib/ops/common.py
@@ -2,13 +2,19 @@
 # Licensed under the MIT License.
 """Common operators shared in the torchlib library."""
 
+# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
+from __future__ import annotations
+
+import numpy.typing as npt
+import onnx
+
 import onnxscript
 import onnxscript.values
-from onnxscript import BOOL, INT64
+from onnxscript import BOOL, INT64, ir
 from onnxscript import opset18 as op
 from onnxscript.function_libs.torch_lib import _constants, tensor_typing
 from onnxscript.function_libs.torch_lib.tensor_typing import RealType
-from onnxscript.onnx_types import COMPLEX64, COMPLEX128, DOUBLE, FLOAT
+from onnxscript.onnx_types import COMPLEX64, COMPLEX128, DOUBLE, FLOAT, TensorType
 
 COMPLEX64_TYPE = COMPLEX64.dtype
 COMPLEX128_TYPE = COMPLEX128.dtype
@@ -56,3 +62,19 @@ def cast_to(a: RealType, dtype: int) -> RealType:
         result = op.Cast(a, to=dtype)
 
     return result
+
+
+def constant(
+    array: npt.ArrayLike | onnx.TensorProto | ir.DLPackCompatible | ir.ArrayCompatible,
+    dtype: int | onnx.TensorProto.DataType | ir.DataType,
+) -> TensorType:
+    """Utility for creating a constant tensor.
+
+    Args:
+        array: The array to convert to a constant tensor.
+        dtype: The data type of the tensor.
+
+    Returns:
+        A constant node.
+    """
+    return op.Constant(value=ir.tensor(array, dtype=ir.DataType(dtype)))
diff --git a/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py b/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py
index 9df42b2af..fa2df9751 100644
--- a/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py
+++ b/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py
@@ -11,6 +11,7 @@
 
 from __future__ import annotations
 
+from onnxscript.function_libs.torch_lib.ops import common
 from onnxscript.function_libs.torch_lib.registration import torch_op
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType
@@ -32,9 +33,8 @@ def quantized_decomposed_quantize_per_tensor(
     quant_max: int,
     dtype: int,
 ) -> TensorType:
-    # TODO(justinchuby): Use quant_min and quant_max
     # TODO(justinchuby): Use dtype when we use opset 21
-    return op.QuantizeLinear(input, scale, zero_point)
+    return op.QuantizeLinear(input, scale, common.constant(zero_point, dtype=dtype))
 
 
 @torch_op(
@@ -54,6 +54,8 @@ def quantized_decomposed_dequantize_per_tensor(
     dtype: int,
     out_dtype: int = -1,
 ) -> TensorType:
-    # TODO(justinchuby): Use quant_min and quant_max
     # TODO(justinchuby): Use dtype when we use opset 21
-    return op.DequantizeLinear(input, scale, zero_point)
+    dequantized = op.DequantizeLinear(input, scale, common.constant(zero_point, dtype=dtype))
+    if out_dtype == -1:
+        return dequantized
+    return op.Cast(dequantized, to=out_dtype)
diff --git a/tests/function_libs/torch_lib/quantization_test.py b/tests/function_libs/torch_lib/quantization_test.py
new file mode 100644
index 000000000..7ec04ee77
--- /dev/null
+++ b/tests/function_libs/torch_lib/quantization_test.py
@@ -0,0 +1,54 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Test quantized model export."""
+
+from __future__ import annotations
+
+import unittest
+
+import onnx
+import torch
+import torch._export as torch_export
+from torch.ao.quantization import quantize_pt2e
+from torch.ao.quantization.quantizer import xnnpack_quantizer
+
+from onnxscript._internal import version_utils
+
+
+class QuantizedModelExportTest(unittest.TestCase):
+    @unittest.skipIf(
+        version_utils.torch_older_than("2.4"),
+        "Dynamo exporter fails at the modularization step.",
+    )
+    def test_simple_quantized_model(self):
+        class TestModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 10)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        example_inputs = (torch.randn(1, 5),)
+        model = TestModel().eval()
+
+        # Step 1. program capture
+        pt2e_torch_model = torch_export.capture_pre_autograd_graph(model, example_inputs)
+
+        # Step 2. quantization
+        quantizer = xnnpack_quantizer.XNNPACKQuantizer().set_global(
+            xnnpack_quantizer.get_symmetric_quantization_config()
+        )
+        pt2e_torch_model = quantize_pt2e.prepare_pt2e(pt2e_torch_model, quantizer)
+
+        # Run the prepared model with sample input data to ensure that internal observers are populated with correct values
+        pt2e_torch_model(*example_inputs)
+
+        # Convert the prepared model to a quantized model
+        pt2e_torch_model = quantize_pt2e.convert_pt2e(pt2e_torch_model, fold_quantize=False)
+        program = torch.onnx.dynamo_export(pt2e_torch_model, *example_inputs)
+        onnx.checker.check_model(program.model_proto, full_check=True)
+
+
+if __name__ == "__main__":
+    unittest.main()