diff --git a/onnxscript/function_libs/torch_lib/ops/common.py b/onnxscript/function_libs/torch_lib/ops/common.py index cae319e2e..87a55620a 100644 --- a/onnxscript/function_libs/torch_lib/ops/common.py +++ b/onnxscript/function_libs/torch_lib/ops/common.py @@ -2,13 +2,19 @@ # Licensed under the MIT License. """Common operators shared in the torchlib library.""" +# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value" +from typing import TYPE_CHECKING + import onnxscript import onnxscript.values -from onnxscript import BOOL, INT64 +from onnxscript import BOOL, INT64, ir from onnxscript import opset18 as op from onnxscript.function_libs.torch_lib import _constants, tensor_typing from onnxscript.function_libs.torch_lib.tensor_typing import RealType -from onnxscript.onnx_types import COMPLEX64, COMPLEX128, DOUBLE, FLOAT +from onnxscript.onnx_types import COMPLEX64, COMPLEX128, DOUBLE, FLOAT, TensorType + +if TYPE_CHECKING: + import onnx COMPLEX64_TYPE = COMPLEX64.dtype COMPLEX128_TYPE = COMPLEX128.dtype @@ -56,3 +62,10 @@ def cast_to(a: RealType, dtype: int) -> RealType: result = op.Cast(a, to=dtype) return result + + +def constant(array, dtype: int | onnx.TensorProto.DataType | ir.DataType) -> TensorType: + """Utility for creating a constant tensor.""" + return op.Constant( + value=ir.serde.serialize_tensor(ir.tensor(array, dtype=ir.DataType(dtype))) + ) diff --git a/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py b/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py index 4ea6b73af..fa2df9751 100644 --- a/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py +++ b/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py @@ -11,21 +11,12 @@ from __future__ import annotations -import onnx - -from onnxscript import ir +from onnxscript.function_libs.torch_lib.ops import common from onnxscript.function_libs.torch_lib.registration import torch_op from onnxscript.onnx_opset import opset18 as op from onnxscript.onnx_types import TensorType -def constant(array, dtype: int | onnx.TensorProto.DataType): - """Utility for creating a constant tensor.""" - return op.Constant( - value=ir.serde.serialize_tensor(ir.tensor(array, dtype=ir.DataType(dtype))) - ) - - @torch_op( ( "quantized_decomposed::quantize_per_tensor", @@ -43,7 +34,7 @@ def quantized_decomposed_quantize_per_tensor( dtype: int, ) -> TensorType: # TODO(justinchuby): Use dtype when we use opset 21 - return op.QuantizeLinear(input, scale, constant(zero_point, dtype=dtype)) + return op.QuantizeLinear(input, scale, common.constant(zero_point, dtype=dtype)) @torch_op( @@ -64,7 +55,7 @@ def quantized_decomposed_dequantize_per_tensor( out_dtype: int = -1, ) -> TensorType: # TODO(justinchuby): Use dtype when we use opset 21 - dequantized = op.DequantizeLinear(input, scale, constant(zero_point, dtype=dtype)) + dequantized = op.DequantizeLinear(input, scale, common.constant(zero_point, dtype=dtype)) if out_dtype == -1: return dequantized return op.Cast(dequantized, to=out_dtype)