Commit: Merge branch 'main' into justinchu/typed-decorator
justinchuby authored Jul 22, 2024
2 parents 234858f + 842f38d commit 89f65f3
Showing 4 changed files with 87 additions and 7 deletions.
@@ -19,7 +19,7 @@
 from typing_extensions import TypeAlias

 import onnxscript
-from onnxscript import evaluator
+from onnxscript import evaluator, ir
 from onnxscript import tensor as onnxscript_tensor
 from onnxscript._internal import param_manipulation, runtime_typing
 from onnxscript.function_libs.torch_lib import _flags

@@ -440,6 +440,8 @@ def _add_attribute_to_torchscript_node(
         return node.s_(key, value)  # type: ignore[arg-type]
     if isinstance(value, torch.Tensor):
         return node.t_(key, value)
+    if isinstance(value, ir.TensorProtocol):
+        return node.t_(key, torch.from_dlpack(value))
     if isinstance(value, Sequence):
         if not value:
             # Treat empty sequences as empty list tensors
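The new ir.TensorProtocol branch hands the tensor to PyTorch via DLPack, avoiding a copy. A minimal sketch of that round trip, assuming an ir.Tensor backed by a NumPy array (variable names here are illustrative, not from the diff):

import numpy as np
import torch
from onnxscript import ir

# ir.Tensor implements the DLPack protocol, so torch can adopt its memory
# directly; this mirrors the torch.from_dlpack(value) call in the diff.
value = ir.Tensor(np.array([1.0, 2.0, 3.0], dtype=np.float32))
torch_tensor = torch.from_dlpack(value)
assert torch_tensor.dtype == torch.float32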
26 changes: 24 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/common.py
@@ -2,13 +2,19 @@
 # Licensed under the MIT License.
 """Common operators shared in the torchlib library."""

+# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
 from __future__ import annotations

+import numpy.typing as npt
+import onnx
+
 import onnxscript
 import onnxscript.values
-from onnxscript import BOOL, INT64
+from onnxscript import BOOL, INT64, ir
 from onnxscript import opset18 as op
 from onnxscript.function_libs.torch_lib import _constants, tensor_typing
 from onnxscript.function_libs.torch_lib.tensor_typing import RealType
-from onnxscript.onnx_types import COMPLEX64, COMPLEX128, DOUBLE, FLOAT
+from onnxscript.onnx_types import COMPLEX64, COMPLEX128, DOUBLE, FLOAT, TensorType

 COMPLEX64_TYPE = COMPLEX64.dtype
 COMPLEX128_TYPE = COMPLEX128.dtype

@@ -56,3 +62,19 @@ def cast_to(a: RealType, dtype: int) -> RealType:
     result = op.Cast(a, to=dtype)

     return result
+
+
+def constant(
+    array: npt.ArrayLike | onnx.TensorProto | ir.DLPackCompatible | ir.ArrayCompatible,
+    dtype: int | onnx.TensorProto.DataType | ir.DataType,
+) -> TensorType:
+    """Utility for creating a constant tensor.
+
+    Args:
+        array: The array to convert to a constant tensor.
+        dtype: The data type of the tensor.
+
+    Returns:
+        A constant node.
+    """
+    return op.Constant(value=ir.tensor(array, dtype=ir.DataType(dtype)))
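The helper centralizes constant creation: callers pass a plain Python value, NumPy array, or TensorProto together with an explicit ONNX element type. A hedged usage sketch as it might appear in another torchlib op (the surrounding tracing context is assumed, and `zero` is an illustrative name):

from onnxscript import ir
from onnxscript.function_libs.torch_lib.ops import common

# Emits a Constant node whose tensor is built by ir.tensor
# with the requested element type.
zero = common.constant(0, dtype=ir.DataType.INT64)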
10 changes: 6 additions & 4 deletions onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py
@@ -11,6 +11,7 @@

 from __future__ import annotations

+from onnxscript.function_libs.torch_lib.ops import common
 from onnxscript.function_libs.torch_lib.registration import torch_op
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType

@@ -32,9 +33,8 @@ def quantized_decomposed_quantize_per_tensor(
     quant_max: int,
     dtype: int,
 ) -> TensorType:
     # TODO(justinchuby): Use quant_min and quant_max
-    # TODO(justinchuby): Use dtype when we use opset 21
-    return op.QuantizeLinear(input, scale, zero_point)
+    return op.QuantizeLinear(input, scale, common.constant(zero_point, dtype=dtype))


 @torch_op(
@@ -54,6 +54,8 @@ def quantized_decomposed_dequantize_per_tensor(
     dtype: int,
     out_dtype: int = -1,
 ) -> TensorType:
     # TODO(justinchuby): Use quant_min and quant_max
-    # TODO(justinchuby): Use dtype when we use opset 21
-    return op.DequantizeLinear(input, scale, zero_point)
+    dequantized = op.DequantizeLinear(input, scale, common.constant(zero_point, dtype=dtype))
+    if out_dtype == -1:
+        return dequantized
+    return op.Cast(dequantized, to=out_dtype)
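For context, opset-18 QuantizeLinear derives its output element type from the zero_point input, which is why the change routes zero_point through common.constant with the requested dtype instead of waiting for opset 21's output_dtype attribute. A rough NumPy model of the per-tensor quantization semantics (illustrative only; the saturation bounds below assume a uint8 target):

import numpy as np

def quantize_per_tensor(x: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    # y = saturate(round(x / scale) + zero_point), per the QuantizeLinear spec;
    # np.rint matches the spec's round-half-to-even behavior.
    q = np.rint(x / scale) + zero_point
    return np.clip(q, 0, 255).astype(np.uint8)

x = np.array([-1.0, 0.0, 0.5, 1.0], dtype=np.float32)
print(quantize_per_tensor(x, scale=1 / 128, zero_point=128))  # [  0 128 192 255]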
54 changes: 54 additions & 0 deletions tests/function_libs/torch_lib/quantization_test.py
@@ -0,0 +1,54 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Test quantized model export."""
+
+from __future__ import annotations
+
+import unittest
+
+import onnx
+import torch
+import torch._export as torch_export
+from torch.ao.quantization import quantize_pt2e
+from torch.ao.quantization.quantizer import xnnpack_quantizer
+
+from onnxscript._internal import version_utils
+
+
+class QuantizedModelExportTest(unittest.TestCase):
+    @unittest.skipIf(
+        version_utils.torch_older_than("2.4"),
+        "Dynamo exporter fails at the modularization step.",
+    )
+    def test_simple_quantized_model(self):
+        class TestModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 10)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        example_inputs = (torch.randn(1, 5),)
+        model = TestModel().eval()
+
+        # Step 1. program capture
+        pt2e_torch_model = torch_export.capture_pre_autograd_graph(model, example_inputs)
+
+        # Step 2. quantization
+        quantizer = xnnpack_quantizer.XNNPACKQuantizer().set_global(
+            xnnpack_quantizer.get_symmetric_quantization_config()
+        )
+        pt2e_torch_model = quantize_pt2e.prepare_pt2e(pt2e_torch_model, quantizer)
+
+        # Run the prepared model with sample input data to ensure that internal observers are populated with correct values
+        pt2e_torch_model(*example_inputs)
+
+        # Convert the prepared model to a quantized model
+        pt2e_torch_model = quantize_pt2e.convert_pt2e(pt2e_torch_model, fold_quantize=False)
+        program = torch.onnx.dynamo_export(pt2e_torch_model, *example_inputs)
+        onnx.checker.check_model(program.model_proto, full_check=True)
+
+
+if __name__ == "__main__":
+    unittest.main()
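The test stops at checker validation. A possible follow-up assertion, sketched here rather than taken from the diff, would confirm that quantize/dequantize pairs survived export; note that the Q/DQ nodes may sit inside model-local functions rather than the top-level graph, depending on exporter settings:

# Hypothetical extension of test_simple_quantized_model:
op_types = {node.op_type for node in program.model_proto.graph.node}
for function in program.model_proto.functions:
    op_types.update(node.op_type for node in function.node)
assert {"QuantizeLinear", "DequantizeLinear"} <= op_types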
