diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 98219a595..bc20bb3f9 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -1193,7 +1193,6 @@ def aten_binomial(
 @torch_op(
     (
-        "aten::bitwise_and",
         "aten::bitwise_and.Tensor",
         "aten::bitwise_and.Scalar",
         "aten::bitwise_and.Scalar_Tensor",
@@ -1207,7 +1206,13 @@ def aten_bitwise_and(self: TInt, other: TInt) -> TInt:
     return op.BitwiseAnd(self, other)
-@torch_op("aten::bitwise_left_shift")
+@torch_op(
+    (
+        "aten::bitwise_left_shift.Tensor",
+        "aten::bitwise_left_shift.Tensor_Scalar",
+        "aten::bitwise_left_shift.Scalar_Tensor",
+    )
+)
 def aten_bitwise_left_shift_int16(self: INT16, other: INT16) -> INT16:
     """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
     # assert other >= 0
@@ -1219,7 +1224,13 @@ def aten_bitwise_left_shift_int16(self: INT16, other: INT16) -> INT16:
     return op.Cast(result, to=INT16.dtype)
-@torch_op("aten::bitwise_left_shift")
+@torch_op(
+    (
+        "aten::bitwise_left_shift.Tensor",
+        "aten::bitwise_left_shift.Tensor_Scalar",
+        "aten::bitwise_left_shift.Scalar_Tensor",
+    )
+)
 def aten_bitwise_left_shift_int32(self: INT32, other: INT32) -> INT32:
     """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
     # assert other >= 0
@@ -1231,7 +1242,13 @@ def aten_bitwise_left_shift_int32(self: INT32, other: INT32) -> INT32:
     return op.Cast(result, to=INT32.dtype)
-@torch_op("aten::bitwise_left_shift")
+@torch_op(
+    (
+        "aten::bitwise_left_shift.Tensor",
+        "aten::bitwise_left_shift.Tensor_Scalar",
+        "aten::bitwise_left_shift.Scalar_Tensor",
+    )
+)
 def aten_bitwise_left_shift_int64(self: INT64, other: INT64) -> INT64:
     """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
     # assert other >= 0
@@ -1243,7 +1260,13 @@ def aten_bitwise_left_shift_int64(self: INT64, other: INT64) -> INT64:
     return op.Cast(result, to=INT64.dtype)
-@torch_op("aten::bitwise_left_shift")
+@torch_op(
+    (
+        "aten::bitwise_left_shift.Tensor",
+        "aten::bitwise_left_shift.Tensor_Scalar",
+        "aten::bitwise_left_shift.Scalar_Tensor",
+    )
+)
 def aten_bitwise_left_shift_int8(self: INT8, other: INT8) -> INT8:
     """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
     # assert other >= 0
@@ -1265,7 +1288,6 @@ def aten_bitwise_not(self: TInt) -> TInt:
 @torch_op(
     (
-        "aten::bitwise_or",
         "aten::bitwise_or.Tensor",
         "aten::bitwise_or.Scalar",
         "aten::bitwise_or.Scalar_Tensor",
@@ -1279,7 +1301,13 @@ def aten_bitwise_or(self: TInt, other: TInt) -> TInt:
     return op.BitwiseOr(self, other)
-@torch_op("aten::bitwise_right_shift")
+@torch_op(
+    (
+        "aten::bitwise_right_shift.Tensor",
+        "aten::bitwise_right_shift.Tensor_Scalar",
+        "aten::bitwise_right_shift.Scalar_Tensor",
+    )
+)
 def aten_bitwise_right_shift_int16(self: INT16, other: INT16) -> INT16:
     """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
     negative = op.Less(self, 0)
@@ -1302,7 +1330,13 @@ def aten_bitwise_right_shift_int16(self: INT16, other: INT16) -> INT16:
     )
-@torch_op("aten::bitwise_right_shift")
+@torch_op(
+    (
+        "aten::bitwise_right_shift.Tensor",
+        "aten::bitwise_right_shift.Tensor_Scalar",
+        "aten::bitwise_right_shift.Scalar_Tensor",
+    )
+)
 def aten_bitwise_right_shift_int32(self: INT32, other: INT32) -> INT32:
     """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
     negative = op.Less(self, 0)
@@ -1325,7 +1359,13 @@ def aten_bitwise_right_shift_int32(self: INT32, other: INT32) -> INT32:
     )
-@torch_op("aten::bitwise_right_shift") +@torch_op( + ( + "aten::bitwise_right_shift.Tensor", + "aten::bitwise_right_shift.Tensor_Scalar", + "aten::bitwise_right_shift.Scalar_Tensor", + ) +) def aten_bitwise_right_shift_int64(self: INT64, other: INT64) -> INT64: """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" negative = op.Less(self, 0) @@ -1351,7 +1391,13 @@ def aten_bitwise_right_shift_int64(self: INT64, other: INT64) -> INT64: ) -@torch_op("aten::bitwise_right_shift") +@torch_op( + ( + "aten::bitwise_right_shift.Tensor", + "aten::bitwise_right_shift.Tensor_Scalar", + "aten::bitwise_right_shift.Scalar_Tensor", + ) +) def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8: """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" negative = op.Less(self, 0) @@ -1376,7 +1422,6 @@ def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8: @torch_op( ( - "aten::bitwise_xor", "aten::bitwise_xor.Tensor", "aten::bitwise_xor.Scalar", "aten::bitwise_xor.Scalar_Tensor", @@ -1450,15 +1495,13 @@ def aten_cat_complex(tensors: Sequence[TTensor], dim: int = 0) -> TTensor: return aten_cat(tensors, dim=dim) -@torch_op("aten::cat") +@torch_op(("aten::cat", "aten::concat", "aten::concatenate"), trace_only=True) def aten_cat(tensors: Sequence[TTensor], dim: int = 0) -> TTensor: """cat(Tensor[] tensors, int dim=0) -> Tensor""" - # NOTE: Having empty tensors when concatenating along non-zero dimension - # is not supported. - # TODO(justinchuby): Filter these tensors out with Sequence ops before - # calling ConcatFromSequence. - return op.ConcatFromSequence(tensors, axis=dim) + # Remove None tensors + tensors = [tensor for tensor in tensors if tensor is not None] + return op.Concat(*tensors, axis=dim) def aten_ccol_indices(self: TensorType) -> TensorType: @@ -1687,22 +1730,6 @@ def aten_complex(real: TFloat, imag: TFloat) -> TFloat: return _aten_complex(real, imag) -@torch_op("aten::concat") -def aten_concat(tensors: Sequence[TTensor], dim: int = 0) -> TTensor: - """concat(Tensor[] tensors, int dim=0) -> Tensor""" - - # TODO(justinchuby): Combine the implementation with cat - return op.ConcatFromSequence(tensors, axis=dim) - - -@torch_op("aten::concatenate") -def aten_concatenate(tensors: Sequence[TTensor], dim: int = 0) -> TTensor: - """concatenate(Tensor[] tensors, int dim=0) -> Tensor""" - - # TODO(justinchuby): Combine the implementation with cat - return op.ConcatFromSequence(tensors, axis=dim) - - @torch_op("aten::conj") def aten_conj(self: TTensor) -> TTensor: """conj(Tensor(a) self) -> Tensor(a)""" @@ -2117,7 +2144,11 @@ def aten_copy( def aten__to_copy( self: TTensor, dtype: int = -1, + layout: str = "", # pylint: disable=unused-argument + device: str = "", # pylint: disable=unused-argument + pin_memory: bool = False, # pylint: disable=unused-argument non_blocking: bool = False, # pylint: disable=unused-argument + memory_format: str = "", # pylint: disable=unused-argument ) -> TTensor: """_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? 
memory_format=None) -> Tensor""" @@ -2686,15 +2717,16 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2.0) -> TensorType @torch_op( ( - "aten::div", "aten::div.Tensor", "aten::div.Scalar", # When rounding_mode is None, performs a true division # https://pytorch.org/docs/stable/generated/torch.div.html "aten::div.Tensor_mode", "aten::div.Scalar_mode", - "aten::divide", - "aten::true_divide", + "aten::divide.Tensor", + "aten::divide.Scalar", + "aten::true_divide.Tensor", + "aten::true_divide.Scalar", "_operator::truediv", ) ) @@ -2707,11 +2739,12 @@ def aten_div(self: TFloat, other: TFloat) -> TFloat: @torch_op( ( - "aten::div", "aten::div.Tensor", "aten::div.Scalar", - "aten::divide", - "aten::true_divide", + "aten::divide.Tensor", + "aten::divide.Scalar", + "aten::true_divide.Tensor", + "aten::true_divide.Scalar", "_operator::truediv", ), complex=True, @@ -2819,7 +2852,7 @@ def aten_einsum( @torch_op("aten::embedding") def aten_embedding( weight: TTensor, - indices: TTensor, + indices: TInt, padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False, @@ -3636,7 +3669,7 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType: @torch_op( - ("aten::ge", "aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal", "_operator::ge") + ("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge") ) def aten_ge(self: TReal, other: TReal) -> BOOL: """ge.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -3644,7 +3677,9 @@ def aten_ge(self: TReal, other: TReal) -> BOOL: return op.GreaterOrEqual(self, other) -@torch_op(("aten::ge", "aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal")) +@torch_op( + ("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge") +) def aten_ge_bool(self: BOOL, other: BOOL) -> BOOL: """ge.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -3792,14 +3827,14 @@ def aten_gru_cell( raise NotImplementedError() -@torch_op(("aten::gt", "aten::gt.Scalar", "aten::greater", "_operator::gt")) +@torch_op(("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt")) def aten_gt(self: TReal, other: TReal) -> BOOL: """gt.Tensor(Tensor self, Tensor other) -> Tensor""" return op.Greater(self, other) -@torch_op(("aten::gt", "aten::gt.Scalar", "aten::greater")) +@torch_op(("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt")) def aten_gt_bool(self: BOOL, other: BOOL) -> BOOL: """gt.Tensor(Tensor self, Tensor other) -> Tensor""" # self, other, self > other @@ -4583,14 +4618,14 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -@torch_op(("aten::le", "aten::le.Tensor", "_operator::le")) +@torch_op(("aten::le.Tensor", "aten::less_equal.Tensor", "_operator::le")) def aten_le(self: TReal, other: TReal) -> BOOL: """le.Tensor(Tensor self, Tensor other) -> Tensor""" return op.LessOrEqual(self, other) -@torch_op(("aten::le", "aten::le.Tensor", "aten::less_equal")) +@torch_op(("aten::le.Tensor", "aten::less_equal.Tensor", "_operator::le")) def aten_le_bool(self: BOOL, other: BOOL) -> BOOL: """le.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -4747,7 +4782,6 @@ def aten_logdet(self: TFloat) -> TFloat: @torch_op( ( "aten::logical_and", - "aten::bitwise_and", "aten::bitwise_and.Tensor", "aten::bitwise_and.Scalar", "aten::bitwise_and.Scalar_Tensor", @@ -4769,7 +4803,6 @@ def aten_logical_not(self: BOOL) -> BOOL: @torch_op( ( "aten::logical_or", - "aten::bitwise_or", "aten::bitwise_or.Tensor", 
"aten::bitwise_or.Scalar", "aten::bitwise_or.Scalar_Tensor", @@ -4786,7 +4819,6 @@ def aten_logical_or(self: BOOL, other: BOOL) -> BOOL: @torch_op( ( "aten::logical_xor", - "aten::bitwise_xor", "aten::bitwise_xor.Tensor", "aten::bitwise_xor.Scalar", "aten::bitwise_xor.Scalar_Tensor", @@ -4879,14 +4911,14 @@ def aten_lstm_mps_backward( raise NotImplementedError() -@torch_op(("aten::lt", "aten::lt.Scalar", "aten::less", "_operator::lt")) +@torch_op(("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt")) def aten_lt(self: TReal, other: TReal) -> BOOL: """lt.Tensor(Tensor self, Tensor other) -> Tensor""" return op.Less(self, other) -@torch_op(("aten::lt", "aten::lt.Scalar", "aten::less")) +@torch_op(("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt")) def aten_lt_bool(self: BOOL, other: BOOL) -> BOOL: """lt.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -4964,7 +4996,7 @@ def aten_margin_ranking_loss( @torch_op( - ("aten::masked_fill", "aten::masked_fill.Scalar", "aten::masked_fill.Tensor"), + ("aten::masked_fill.Scalar", "aten::masked_fill.Tensor"), traceable=True, ) def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor: @@ -6486,9 +6518,7 @@ def aten_positive(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op( - ("aten::pow", "aten::pow.Tensor_Tensor", "aten::pow.Tensor_Scalar", "_operator::pow") -) +@torch_op(("aten::pow.Tensor_Tensor", "aten::pow.Tensor_Scalar", "_operator::pow")) def aten_pow(self: TReal, exponent: TTensor) -> TReal: """pow(Tensor self, Tensor exponent) -> Tensor""" @@ -7226,7 +7256,13 @@ def aten_rsub_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: @torch_op("aten::scalar_tensor", trace_only=True) -def aten_scalar_tensor(s: float, dtype: int = FLOAT.dtype) -> RealType: +def aten_scalar_tensor( + s: float, + dtype: int = FLOAT.dtype, + layout: str = "", # pylint: disable=unused-argument + device: str = "", # pylint: disable=unused-argument + pin_memory: bool = False, # pylint: disable=unused-argument +) -> RealType: """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor""" # Set trace_only=True because different if branches return different dtypes @@ -7275,7 +7311,7 @@ def aten_scatter_add( return op.ScatterElements(self, index, src, axis=dim, reduction="add") -@torch_op(("aten::scatter_reduce", "aten::scatter_reduce.two"), trace_only=True) +@torch_op("aten::scatter_reduce.two", trace_only=True) def aten_scatter_reduce( self: TReal, dim: int, # we have to use int here because ScatterElements() will use this attribute @@ -8295,7 +8331,7 @@ def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]: return op.SplitToSequence(self, split_sizes, axis=dim, keepdims=False) -@torch_op("aten::unflatten") +@torch_op("aten::unflatten.int") def aten_unflatten(self: TReal, dim: INT64, sizes: INT64): """unflatten(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a)""" @@ -8641,7 +8677,7 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::view") +@torch_op(("aten::view", "aten::_unsafe_view")) def aten_view(self: TTensor, size: IntType) -> TTensor: """view(Tensor(a) self, SymInt[] size) -> Tensor(a)""" @@ -8649,7 +8685,7 @@ def aten_view(self: TTensor, size: IntType) -> TTensor: return op.Reshape(self, size) -@torch_op("aten::view", complex=True) +@torch_op(("aten::view", "aten::_unsafe_view"), complex=True) def aten_view_complex(self: TTensor, size: IntType) -> TTensor: """view(Tensor(a) self, SymInt[] size) -> Tensor(a)""" diff --git a/onnxscript/function_libs/torch_lib/ops/special.py b/onnxscript/function_libs/torch_lib/ops/special.py index 6719581f6..bf4746261 100644 --- a/onnxscript/function_libs/torch_lib/ops/special.py +++ b/onnxscript/function_libs/torch_lib/ops/special.py @@ -214,7 +214,7 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op(("aten::log_softmax", "aten::special_log_softmax"), trace_only=True) +@torch_op(("aten::log_softmax.int", "aten::special_log_softmax"), trace_only=True) def aten_special_log_softmax( self: TFloatOrBFloat16, dim: int, dtype: int = -1 ) -> TFloatOrBFloat16: @@ -364,7 +364,7 @@ def aten_special_xlog1py(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::xlogy") +@torch_op(("aten::xlogy.Tensor", "aten::xlogy.Scalar_Self", "aten::xlogy.Scalar_Other")) def aten_special_xlogy(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16: """special_xlogy(Tensor self, Tensor other) -> Tensor""" diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 1442ba5e9..7eeba0493 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -1444,11 +1444,12 @@ def __init__( def __repr__(self) -> str: value_name = self.name if self.name else "anonymous:" + str(id(self)) producer = self.producer() - producer_text = ( - producer.name is not None or "anonymous_node:" + str(id(producer)) - if producer is not None - else None - ) + if producer is None: + producer_text = "None" + elif producer.name is not None: + producer_text = producer.name + else: + producer_text = f"anonymous_node:{id(producer)}" return f"{self.__class__.__name__}({value_name!r}, type={self.type!r}, shape={self.shape}, producer={producer_text}, index={self.index()})" def __str__(self) -> str: @@ -2413,7 +2414,7 @@ def __str__(self) -> str: inputs_text = ",\n".join(str(x) for x in self.inputs) outputs_text = ",\n".join(str(x) for x in self.outputs) attributes_text = ",\n".join( - f"{attr.name}: {attr.type}" + f" = {attr.value}" * (attr.value is None) + 
f"{attr.name}: {attr.type}" + f" = {attr.value}" * (attr.value is not None) for attr in self.attributes.values() ) if attributes_text: diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index a435d599e..1af6223b1 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -13,6 +13,8 @@ from __future__ import annotations +import functools + __all__ = [ # Tensors "TensorProtoTensor", @@ -50,13 +52,14 @@ "serialize_type_into", "serialize_value_into", "serialize_value", + "SerdeError", ] import collections import logging import os import typing -from typing import Any, List, Mapping, Sequence +from typing import Any, Callable, List, Mapping, Sequence import numpy as np import onnx @@ -70,9 +73,35 @@ logger = logging.getLogger(__name__) +_PLEASE_CONTRIBUTE = ( + "Please contribute by creating a PR at https://github.com/microsoft/onnxscript." +) _FUNCTION_VALUE_INFO_SUPPORTED_VERSION = ( 10 # ONNX IR version where value info in functions was introduced ) +_T = typing.TypeVar("_T", bound=Callable[..., Any]) + + +class SerdeError(RuntimeError): + """Error during serialization or deserialization.""" + + +def _capture_errors(arg_capturer: Callable[..., str]) -> Callable[[_T], _T]: + """Decorator to capture errors and display the stack.""" + + def decorator(func: _T) -> _T: + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + try: + return func(*args, **kwargs) + except Exception as e: + raise SerdeError( + f"Error calling {func.__name__} with: {arg_capturer(*args, **kwargs)}" + ) from e + + return wrapper # type: ignore + + return decorator def _little_endian_dtype(dtype) -> np.dtype: @@ -98,7 +127,8 @@ def from_proto( | onnx.TensorProto | onnx.AttributeProto | onnx.ValueInfoProto - | onnx.TypeProto, + | onnx.TypeProto + | onnx.FunctionProto, ) -> Any: """Deserialize an ONNX proto message to an IR object.""" if isinstance(proto, onnx.ModelProto): @@ -118,6 +148,8 @@ def from_proto( deserialize_type_proto_for_type(proto), deserialize_type_proto_for_shape(proto), ) + if isinstance(proto, onnx.FunctionProto): + return deserialize_function(proto) raise NotImplementedError( f"Deserialization of {type(proto)} in from_proto is not implemented. " "Use a specific ir.serde.deserialize* function instead." @@ -133,7 +165,8 @@ def to_proto( | _protocols.ReferenceAttributeProtocol | _protocols.TensorProtocol | _protocols.TypeProtocol - | _protocols.GraphViewProtocol, + | _protocols.GraphViewProtocol + | _protocols.FunctionProtocol, ) -> Any: """Serialize an IR object to a proto.""" if isinstance(ir_object, _protocols.ModelProtocol): @@ -154,6 +187,8 @@ def to_proto( return serialize_type_into(onnx.TypeProto(), ir_object) if isinstance(ir_object, _protocols.GraphViewProtocol): return serialize_graph(ir_object) + if isinstance(ir_object, _protocols.FunctionProtocol): + return serialize_function(ir_object) raise NotImplementedError( f"Serialization of {type(ir_object)} in to_proto is not implemented. " "Use a specific ir.serde.serialize* function instead." 
@@ -509,6 +544,7 @@ def deserialize_graph(proto: onnx.GraphProto) -> _core.Graph:
     return _deserialize_graph(proto, [])
+@_capture_errors(lambda proto, scoped_values: proto.name)
 def _deserialize_graph(
     proto: onnx.GraphProto, scoped_values: list[dict[str, _core.Value]]
 ) -> _core.Graph:
@@ -573,6 +609,7 @@ def _deserialize_graph(
     )
+@_capture_errors(lambda proto: proto.name)
 def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
     inputs = [_core.Input(name) for name in proto.input]
     values: dict[str, _core.Value] = {v.name: v for v in inputs}  # type: ignore[misc]
@@ -609,6 +646,7 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
     )
+@_capture_errors(lambda proto, value: str(proto))
 def deserialize_value_info_proto(
     proto: onnx.ValueInfoProto, value: _core.Value | None
 ) -> _core.Value:
@@ -623,6 +661,7 @@ def deserialize_value_info_proto(
     return value
+@_capture_errors(str)
 def deserialize_type_proto_for_shape(proto: onnx.TypeProto) -> _core.Shape | None:
     if proto.HasField("tensor_type"):
         if (shape_proto := _get_field(proto.tensor_type, "shape")) is None:
@@ -655,11 +694,12 @@ def deserialize_type_proto_for_shape(proto: onnx.TypeProto) -> _core.Shape | Non
         return deserialize_type_proto_for_shape(elem_type)
     if proto.HasField("map_type"):
         # TODO(justinchuby): Do we need to support map types?
-        raise NotImplementedError("Map types are not supported yet")
+        raise NotImplementedError(f"Map types are not supported yet. {_PLEASE_CONTRIBUTE}")
     return None
+@_capture_errors(str)
 def deserialize_type_proto_for_type(
     proto: onnx.TypeProto,
 ) -> _protocols.TypeProtocol | None:
@@ -690,11 +730,12 @@ def deserialize_type_proto_for_type(
         return _core.OptionalType(nested_type, denotation=denotation)
     if proto.HasField("map_type"):
         # TODO(justinchuby): Do we need to support map types?
-        raise NotImplementedError("Map types are not supported yet")
+        raise NotImplementedError(f"Map types are not supported yet. {_PLEASE_CONTRIBUTE}")
     return None
+@_capture_errors(str)
 def deserialize_dimension(
     proto: onnx.TensorShapeProto.Dimension,
 ) -> tuple[int | _core.SymbolicDim, str | None]:
@@ -717,6 +758,7 @@ def deserialize_dimension(
     return _core.SymbolicDim(None), denotation
+@_capture_errors(lambda proto, base_path: proto.name)
 def deserialize_tensor(
     proto: onnx.TensorProto, base_path: str | os.PathLike = ""
 ) -> _protocols.TensorProtocol:
@@ -760,6 +802,7 @@ def deserialize_attribute(proto: onnx.AttributeProto) -> _core.Attr | _core.RefA
     return _deserialize_attribute(proto, [])
+@_capture_errors(lambda proto, scoped_values: str(proto))
 def _deserialize_attribute(
     proto: onnx.AttributeProto, scoped_values: list[dict[str, _core.Value]]
 ) -> _core.Attr | _core.RefAttr:
@@ -803,9 +846,13 @@ def _deserialize_attribute(
             doc_string=doc_string,
         )
     if type_ == _enums.AttributeType.SPARSE_TENSOR:
-        raise NotImplementedError("Sparse tensors are not supported yet")
+        raise NotImplementedError(
+            f"Sparse tensors are not supported yet. {_PLEASE_CONTRIBUTE}"
+        )
     if type_ == _enums.AttributeType.SPARSE_TENSORS:
-        raise NotImplementedError("Sparse tensors are not supported yet")
+        raise NotImplementedError(
+            f"Sparse tensors are not supported yet. {_PLEASE_CONTRIBUTE}"
+        )
     if type_ == _enums.AttributeType.TYPE_PROTO:
         ir_type = deserialize_type_proto_for_type(proto.tp)
         shape = deserialize_type_proto_for_shape(proto.tp)
@@ -828,6 +875,7 @@ def deserialize_node(proto: onnx.NodeProto) -> _core.Node:
     return _deserialize_node(proto, scoped_values=[], value_info={})
+@_capture_errors(lambda proto, scoped_values, value_info: str(proto))
 def _deserialize_node(
     proto: onnx.NodeProto,
     scoped_values: list[dict[str, _core.Value]],
@@ -936,6 +984,12 @@ def serialize_model(model: _protocols.ModelProtocol) -> onnx.ModelProto:
     return serialize_model_into(onnx.ModelProto(), from_=model)
+@_capture_errors(
+    lambda model_proto, from_: (
+        f"ir_version={from_.ir_version}, producer_name={from_.producer_name}, "
+        f"producer_version={from_.producer_version}, domain={from_.domain}, "
+    )
+)
 def serialize_model_into(
     model_proto: onnx.ModelProto, from_: _protocols.ModelProtocol
 ) -> onnx.ModelProto:
@@ -1086,6 +1140,13 @@ def serialize_graph(
     return graph_proto
+@_capture_errors(
+    lambda graph_proto, from_: (
+        f"name={from_.name}, doc_string={from_.doc_string}, "
+        f"len(inputs)={len(from_.inputs)}, len(initializers)={len(from_.initializers)}, "
+        f"len(nodes)={len(from_)}, len(outputs)={len(from_.outputs)}, metadata_props={from_.metadata_props}"
+    )
+)
 def serialize_graph_into(
     graph_proto: onnx.GraphProto,
     from_: _protocols.GraphProtocol | _protocols.GraphViewProtocol,
@@ -1140,6 +1201,7 @@ def serialize_function(
     return function_proto
+@_capture_errors(lambda function_proto, from_, create_value_info: repr(from_))
 def serialize_function_into(
     function_proto: onnx.FunctionProto,
     from_: _protocols.FunctionProtocol,
@@ -1205,6 +1267,7 @@ def serialize_node(node: _protocols.NodeProtocol) -> onnx.NodeProto:
     return node_proto
+@_capture_errors(lambda node_proto, from_: repr(from_))
 def serialize_node_into(node_proto: onnx.NodeProto, from_: _protocols.NodeProtocol) -> None:
     node_proto.op_type = from_.op_type
     if from_.domain:
@@ -1248,6 +1311,7 @@ def serialize_tensor(tensor: _protocols.TensorProtocol) -> onnx.TensorProto:
     return tensor_proto
+@_capture_errors(lambda tensor_proto, from_: repr(from_))
 def serialize_tensor_into(
     tensor_proto: onnx.TensorProto, from_: _protocols.TensorProtocol
 ) -> None:
@@ -1289,6 +1353,7 @@ def serialize_attribute(attribute: _protocols.AttributeProtocol) -> onnx.Attribu
     return attribute_proto
+@_capture_errors(lambda attribute_proto, from_: repr(from_))
 def serialize_attribute_into(
     attribute_proto: onnx.AttributeProto, from_: _protocols.AttributeProtocol
 ) -> None:
@@ -1344,9 +1409,13 @@ def _fill_in_value_for_attribute(
         serialize_graph_into(attribute_proto.graphs.add(), graph)
         attribute_proto.type = onnx.AttributeProto.GRAPHS
     elif type_ == _enums.AttributeType.SPARSE_TENSOR:
-        raise NotImplementedError("Sparse tensors are not supported yet")
+        raise NotImplementedError(
+            f"Sparse tensors are not supported yet. {_PLEASE_CONTRIBUTE}"
+        )
     elif type_ == _enums.AttributeType.SPARSE_TENSORS:
-        raise NotImplementedError("Sparse tensors are not supported yet")
+        raise NotImplementedError(
+            f"Sparse tensors are not supported yet. {_PLEASE_CONTRIBUTE}"
+        )
     elif type_ == _enums.AttributeType.TYPE_PROTO:
         # value: _core.TypeAndShape
         if value.type is not None:
@@ -1369,6 +1438,7 @@ def _fill_in_value_for_attribute(
         raise TypeError(f"Unsupported attribute type: {type_}")
+@_capture_errors(lambda attribute_proto, from_: repr(from_))
 def serialize_reference_attribute_into(
     attribute_proto: onnx.AttributeProto, from_: _protocols.ReferenceAttributeProtocol
 ) -> None:
@@ -1392,6 +1462,7 @@ def serialize_value(value: _protocols.ValueProtocol, *, name: str = "") -> onnx.
     return value_info_proto
+@_capture_errors(lambda value_info_proto, from_: repr(from_))
 def serialize_value_into(
     value_info_proto: onnx.ValueInfoProto,
     from_: _protocols.ValueProtocol,
@@ -1420,6 +1491,7 @@ def serialize_value_into(
     value_info_proto.doc_string = from_.doc_string
+@_capture_errors(lambda type_proto, from_: repr(from_))
 def serialize_type_into(type_proto: onnx.TypeProto, from_: _protocols.TypeProtocol) -> None:
     if from_.denotation:
         type_proto.denotation = from_.denotation
@@ -1439,6 +1511,7 @@ def serialize_type_into(type_proto: onnx.TypeProto, from_: _protocols.TypeProtoc
     raise TypeError(f"Unsupported type: {from_}")
+@_capture_errors(lambda type_proto, from_: repr(from_))
 def serialize_shape_into(type_proto: onnx.TypeProto, from_: _protocols.ShapeProtocol) -> None:
     value_field = type_proto.WhichOneof("value")
     tensor_type = getattr(type_proto, value_field)
@@ -1454,6 +1527,7 @@ def serialize_shape_into(type_proto: onnx.TypeProto, from_: _protocols.ShapeProt
         serialize_dimension_into(tensor_type.shape.dim.add(), dim, denotation)
+@_capture_errors(lambda dim_proto, dim, denotation: repr(dim_proto))
 def serialize_dimension_into(
     dim_proto: onnx.TensorShapeProto.Dimension,
     dim: int | _protocols.SymbolicDimProtocol,
diff --git a/onnxscript/values.py b/onnxscript/values.py
index 8e36cdfa2..40e030262 100644
--- a/onnxscript/values.py
+++ b/onnxscript/values.py
@@ -527,6 +527,9 @@ def __call__(self, *args, **kwargs):
         return evaluator.default().eval_function(self, args, kwargs)
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({self.function!r})"
+
     def param_schemas(self) -> tuple[ParamSchema, ...]:
         """Returns the parameter schemas of this function."""
         if self._param_schemas is not None:
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index dc35df650..ab3e204af 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -705,7 +705,7 @@ def _where_input_wrangler(
     TorchLibOpInfo("bitwise_xor", core_ops.aten_bitwise_xor),
     TorchLibOpInfo("bmm", core_ops.aten_bmm),
     TorchLibOpInfo("broadcast_to", core_ops.aten_broadcast_to),
-    TorchLibOpInfo("cat", core_ops.aten_cat).skip(
+    TorchLibOpInfo("cat", core_ops.aten_cat, trace_only=True).skip(
         matcher=lambda sample: sample.input[0].equal(torch.tensor([])),
         reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619",
     ),
@@ -739,11 +739,11 @@ def _where_input_wrangler(
     ),
     TorchLibOpInfo("clone", core_ops.aten_clone),
     TorchLibOpInfo("complex", core_ops.aten_complex, trace_only=True),
-    TorchLibOpInfo("concat", core_ops.aten_concat).skip(
+    TorchLibOpInfo("concat", core_ops.aten_cat, trace_only=True).skip(
         matcher=lambda sample: sample.input[0].equal(torch.tensor([])),
         reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619",
     ),
-    TorchLibOpInfo("concatenate", core_ops.aten_concatenate).skip(
+    TorchLibOpInfo("concatenate", core_ops.aten_cat, trace_only=True).skip(
         matcher=lambda sample: sample.input[0].equal(torch.tensor([])),
         reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619",
     ),
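
For reviewers unfamiliar with the error-capturing decorator introduced in `onnxscript/ir/serde.py` above, here is a minimal, self-contained sketch of the same pattern. `SerdeError` and `_capture_errors` are copied from the diff; the decorated `parse_positive_int` helper and its inputs are hypothetical, used only to show how a failure surfaces with the captured argument snapshot while the original exception stays attached as `__cause__`.

```python
import functools
from typing import Any, Callable, TypeVar

_T = TypeVar("_T", bound=Callable[..., Any])


class SerdeError(RuntimeError):
    """Error during serialization or deserialization."""


def _capture_errors(arg_capturer: Callable[..., str]) -> Callable[[_T], _T]:
    """Re-raise any failure as SerdeError, annotated with a snapshot of the arguments."""

    def decorator(func: _T) -> _T:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            try:
                return func(*args, **kwargs)
            except Exception as e:
                raise SerdeError(
                    f"Error calling {func.__name__} with: {arg_capturer(*args, **kwargs)}"
                ) from e

        return wrapper  # type: ignore[return-value]

    return decorator


# Hypothetical helper standing in for the deserialize_*/serialize_* functions in the diff.
@_capture_errors(lambda text: repr(text))
def parse_positive_int(text: str) -> int:
    value = int(text)
    if value <= 0:
        raise ValueError(f"expected a positive integer, got {value}")
    return value


print(parse_positive_int("8"))  # 8
try:
    parse_positive_int("batch")
except SerdeError as err:
    print(err)            # Error calling parse_positive_int with: 'batch'
    print(err.__cause__)  # invalid literal for int() with base 10: 'batch'
```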