From 26b335a57f6dc8710e56c5ad4b9aad1cc4cc55cb Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 19 Jul 2024 00:04:02 +0000 Subject: [PATCH] constant --- .../torch_lib/graph_building/_graph_building_torch.py | 4 +++- onnxscript/function_libs/torch_lib/ops/common.py | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py index 5e0a48077..54aa412ff 100644 --- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py +++ b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py @@ -19,7 +19,7 @@ from typing_extensions import TypeAlias import onnxscript -from onnxscript import evaluator +from onnxscript import evaluator, ir from onnxscript import tensor as onnxscript_tensor from onnxscript._internal import param_manipulation, runtime_typing from onnxscript.function_libs.torch_lib import _flags @@ -440,6 +440,8 @@ def _add_attribute_to_torchscript_node( return node.s_(key, value) # type: ignore[arg-type] if isinstance(value, torch.Tensor): return node.t_(key, value) + if isinstance(value, ir.TensorProtocol): + return node.t_(key, torch.from_dlpack(value)) if isinstance(value, Sequence): if not value: # Treat empty sequences as empty list tensors diff --git a/onnxscript/function_libs/torch_lib/ops/common.py b/onnxscript/function_libs/torch_lib/ops/common.py index 87a55620a..3e0a2ec8a 100644 --- a/onnxscript/function_libs/torch_lib/ops/common.py +++ b/onnxscript/function_libs/torch_lib/ops/common.py @@ -3,6 +3,8 @@ """Common operators shared in the torchlib library.""" # mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value" +from __future__ import annotations + from typing import TYPE_CHECKING import onnxscript @@ -66,6 +68,4 @@ def cast_to(a: RealType, dtype: int) -> RealType: def constant(array, dtype: int | onnx.TensorProto.DataType | ir.DataType) -> TensorType: """Utility for creating a constant tensor.""" - return op.Constant( - value=ir.serde.serialize_tensor(ir.tensor(array, dtype=ir.DataType(dtype))) - ) + return op.Constant(value=ir.tensor(array, dtype=ir.DataType(dtype)))