Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Jul 22, 2024
1 parent 64e6416 commit ce37828
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
21 changes: 15 additions & 6 deletions onnxscript/function_libs/torch_lib/ops/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
from __future__ import annotations

from typing import TYPE_CHECKING
import numpy.typing as npt
import onnx

import onnxscript
import onnxscript.values
Expand All @@ -15,9 +16,6 @@
from onnxscript.function_libs.torch_lib.tensor_typing import RealType
from onnxscript.onnx_types import COMPLEX64, COMPLEX128, DOUBLE, FLOAT, TensorType

if TYPE_CHECKING:
import onnx

COMPLEX64_TYPE = COMPLEX64.dtype
COMPLEX128_TYPE = COMPLEX128.dtype

Expand Down Expand Up @@ -66,6 +64,17 @@ def cast_to(a: RealType, dtype: int) -> RealType:
return result


def constant(array, dtype: int | onnx.TensorProto.DataType | ir.DataType) -> TensorType:
"""Utility for creating a constant tensor."""
def constant(
array: npt.ArrayLike | onnx.TensorProto | ir.DLPackCompatible | ir.ArrayCompatible,
dtype: int | onnx.TensorProto.DataType | ir.DataType,
) -> TensorType:
"""Utility for creating a constant tensor.
Args:
array: The array to convert to a constant tensor.
dtype: The data type of the tensor.
Returns:
A constant node.
"""
return op.Constant(value=ir.tensor(array, dtype=ir.DataType(dtype)))

Check warning on line 80 in onnxscript/function_libs/torch_lib/ops/common.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/common.py#L80

Added line #L80 was not covered by tests
6 changes: 6 additions & 0 deletions tests/function_libs/torch_lib/quantization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,14 @@
from torch.ao.quantization import quantize_pt2e
from torch.ao.quantization.quantizer import xnnpack_quantizer

from onnxscript._internal import version_utils


class QuantizedModelExportTest(unittest.TestCase):
@unittest.skipIf(
version_utils.torch_older_than("2.4"),
"Dynamo exporter fails at the modularization step.",
)
def test_simple_quantized_model(self):
class TestModel(torch.nn.Module):
def __init__(self):
Expand Down

0 comments on commit ce37828

Please sign in to comment.