Skip to content

Commit

Permalink
constant
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Jul 19, 2024
1 parent 58cdcd6 commit 26b335a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing_extensions import TypeAlias

import onnxscript
from onnxscript import evaluator
from onnxscript import evaluator, ir
from onnxscript import tensor as onnxscript_tensor
from onnxscript._internal import param_manipulation, runtime_typing
from onnxscript.function_libs.torch_lib import _flags
Expand Down Expand Up @@ -440,6 +440,8 @@ def _add_attribute_to_torchscript_node(
return node.s_(key, value) # type: ignore[arg-type]
if isinstance(value, torch.Tensor):
return node.t_(key, value)
if isinstance(value, ir.TensorProtocol):
return node.t_(key, torch.from_dlpack(value))
if isinstance(value, Sequence):
if not value:
# Treat empty sequences as empty list tensors
Expand Down
6 changes: 3 additions & 3 deletions onnxscript/function_libs/torch_lib/ops/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"""Common operators shared in the torchlib library."""

# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
from __future__ import annotations

from typing import TYPE_CHECKING

import onnxscript
Expand Down Expand Up @@ -66,6 +68,4 @@ def cast_to(a: RealType, dtype: int) -> RealType:

def constant(array, dtype: int | onnx.TensorProto.DataType | ir.DataType) -> TensorType:
"""Utility for creating a constant tensor."""
return op.Constant(
value=ir.serde.serialize_tensor(ir.tensor(array, dtype=ir.DataType(dtype)))
)
return op.Constant(value=ir.tensor(array, dtype=ir.DataType(dtype)))

0 comments on commit 26b335a

Please sign in to comment.