Skip to content

Commit

Permalink
common
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Jul 18, 2024
1 parent 69cddd5 commit 58cdcd6
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 14 deletions.
17 changes: 15 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,19 @@
# Licensed under the MIT License.
"""Common operators shared in the torchlib library."""

# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
from typing import TYPE_CHECKING

import onnxscript
import onnxscript.values
from onnxscript import BOOL, INT64
from onnxscript import BOOL, INT64, ir
from onnxscript import opset18 as op
from onnxscript.function_libs.torch_lib import _constants, tensor_typing
from onnxscript.function_libs.torch_lib.tensor_typing import RealType
from onnxscript.onnx_types import COMPLEX64, COMPLEX128, DOUBLE, FLOAT
from onnxscript.onnx_types import COMPLEX64, COMPLEX128, DOUBLE, FLOAT, TensorType

if TYPE_CHECKING:
import onnx

Check warning on line 17 in onnxscript/function_libs/torch_lib/ops/common.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/common.py#L17

Added line #L17 was not covered by tests

COMPLEX64_TYPE = COMPLEX64.dtype
COMPLEX128_TYPE = COMPLEX128.dtype
Expand Down Expand Up @@ -56,3 +62,10 @@ def cast_to(a: RealType, dtype: int) -> RealType:
result = op.Cast(a, to=dtype)

return result


def constant(array, dtype: int | onnx.TensorProto.DataType | ir.DataType) -> TensorType:

Check failure

Code scanning / lintrunner

PYLINT/E0601 Error

Using variable 'onnx' before assignment (used-before-assignment)
See used-before-assignment. To disable, use # pylint: disable=used-before-assignment
"""Utility for creating a constant tensor."""
return op.Constant(

Check warning on line 69 in onnxscript/function_libs/torch_lib/ops/common.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/common.py#L69

Added line #L69 was not covered by tests
value=ir.serde.serialize_tensor(ir.tensor(array, dtype=ir.DataType(dtype)))
)
15 changes: 3 additions & 12 deletions onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,12 @@

from __future__ import annotations

import onnx

from onnxscript import ir
from onnxscript.function_libs.torch_lib.ops import common
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType


def constant(array, dtype: int | onnx.TensorProto.DataType):
"""Utility for creating a constant tensor."""
return op.Constant(
value=ir.serde.serialize_tensor(ir.tensor(array, dtype=ir.DataType(dtype)))
)


@torch_op(
(
"quantized_decomposed::quantize_per_tensor",
Expand All @@ -43,7 +34,7 @@ def quantized_decomposed_quantize_per_tensor(
dtype: int,
) -> TensorType:
# TODO(justinchuby): Use dtype when we use opset 21
return op.QuantizeLinear(input, scale, constant(zero_point, dtype=dtype))
return op.QuantizeLinear(input, scale, common.constant(zero_point, dtype=dtype))


@torch_op(
Expand All @@ -64,7 +55,7 @@ def quantized_decomposed_dequantize_per_tensor(
out_dtype: int = -1,
) -> TensorType:
# TODO(justinchuby): Use dtype when we use opset 21
dequantized = op.DequantizeLinear(input, scale, constant(zero_point, dtype=dtype))
dequantized = op.DequantizeLinear(input, scale, common.constant(zero_point, dtype=dtype))
if out_dtype == -1:
return dequantized
return op.Cast(dequantized, to=out_dtype)

0 comments on commit 58cdcd6

Please sign in to comment.