Skip to content

Commit

Permalink
Merge branch 'main' into rama/simple-multi-output
Browse files Browse the repository at this point in the history
  • Loading branch information
gramalingam authored Jul 24, 2024
2 parents 65f647d + 712aa87 commit ee10468
Show file tree
Hide file tree
Showing 16 changed files with 514 additions and 304 deletions.
4 changes: 3 additions & 1 deletion onnxscript/_internal/runtime_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
T = typing.TypeVar("T", bound=typing.Callable[..., typing.Any])

try:
from beartype import beartype as checked
from beartype import beartype as _beartype_decorator
from beartype import roar as _roar

checked = typing.cast(typing.Callable[[T], T], _beartype_decorator)

# Beartype warns when we import from typing because the types are deprecated
# in Python 3.9. But there will be a long time until we can move to using
# the native container types for type annotations (when 3.9 is the lowest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing_extensions import TypeAlias

import onnxscript
from onnxscript import evaluator
from onnxscript import evaluator, ir
from onnxscript import tensor as onnxscript_tensor
from onnxscript._internal import param_manipulation, runtime_typing
from onnxscript.function_libs.torch_lib import _flags
Expand Down Expand Up @@ -425,11 +425,19 @@ def eval_function( # type: ignore[override]
return self._graph.add_function_call(function, inputs, attributes)


@runtime_typing.checked
def _add_attribute_to_torchscript_node(
node: torch.Node,
key: str,
value: Union[float, int, str, bytes, Sequence[float], Sequence[int], torch.Tensor],
value: Union[
float,
int,
str,
bytes,
Sequence[float],
Sequence[int],
torch.Tensor,
ir.TensorProtocol,
],
):
"""Initializes the right attribute based on type of value."""
if isinstance(value, float):
Expand All @@ -440,6 +448,8 @@ def _add_attribute_to_torchscript_node(
return node.s_(key, value) # type: ignore[arg-type]
if isinstance(value, torch.Tensor):
return node.t_(key, value)
if isinstance(value, ir.TensorProtocol):
return node.t_(key, torch.from_dlpack(value))
if isinstance(value, Sequence):
if not value:
# Treat empty sequences as empty list tensors
Expand Down
26 changes: 24 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,19 @@
# Licensed under the MIT License.
"""Common operators shared in the torchlib library."""

# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
from __future__ import annotations

import numpy.typing as npt
import onnx

import onnxscript
import onnxscript.values
from onnxscript import BOOL, INT64
from onnxscript import BOOL, INT64, ir
from onnxscript import opset18 as op
from onnxscript.function_libs.torch_lib import _constants, tensor_typing
from onnxscript.function_libs.torch_lib.tensor_typing import RealType
from onnxscript.onnx_types import COMPLEX64, COMPLEX128, DOUBLE, FLOAT
from onnxscript.onnx_types import COMPLEX64, COMPLEX128, DOUBLE, FLOAT, TensorType

COMPLEX64_TYPE = COMPLEX64.dtype
COMPLEX128_TYPE = COMPLEX128.dtype
Expand Down Expand Up @@ -56,3 +62,19 @@ def cast_to(a: RealType, dtype: int) -> RealType:
result = op.Cast(a, to=dtype)

return result


def constant(
array: npt.ArrayLike | onnx.TensorProto | ir.DLPackCompatible | ir.ArrayCompatible,
dtype: int | onnx.TensorProto.DataType | ir.DataType,
) -> TensorType:
"""Utility for creating a constant tensor.
Args:
array: The array to convert to a constant tensor.
dtype: The data type of the tensor.
Returns:
A constant node.
"""
return op.Constant(value=ir.tensor(array, dtype=ir.DataType(dtype)))
Loading

0 comments on commit ee10468

Please sign in to comment.