diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index c9fb79f61..56b6a0dc8 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -550,9 +550,6 @@ def aten_arange(
 ) -> TensorType:
     """arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
 
-    # NOTE: trace_only because both if branches need to be the same type, but we have
-    # a cast in the if branch.
-
     if dtype == -1:
         zero = op.CastLike(0.0, end)
         one = op.CastLike(1.0, end)
@@ -1229,6 +1226,7 @@ def aten_bitwise_and(self: TInt, other: TInt) -> TInt:
         "aten::bitwise_left_shift.Tensor",
         "aten::bitwise_left_shift.Tensor_Scalar",
         "aten::bitwise_left_shift.Scalar_Tensor",
+        "_operator::__lshift__",
     ),
     traceable=True,
 )
@@ -1248,6 +1246,7 @@ def aten_bitwise_left_shift_int16(self: INT16, other: INT16) -> INT16:
         "aten::bitwise_left_shift.Tensor",
         "aten::bitwise_left_shift.Tensor_Scalar",
         "aten::bitwise_left_shift.Scalar_Tensor",
+        "_operator::__lshift__",
     ),
     traceable=True,
 )
@@ -1267,6 +1266,7 @@ def aten_bitwise_left_shift_int32(self: INT32, other: INT32) -> INT32:
         "aten::bitwise_left_shift.Tensor",
         "aten::bitwise_left_shift.Tensor_Scalar",
         "aten::bitwise_left_shift.Scalar_Tensor",
+        "_operator::__lshift__",
     ),
     traceable=True,
 )
@@ -1286,6 +1286,7 @@ def aten_bitwise_left_shift_int64(self: INT64, other: INT64) -> INT64:
         "aten::bitwise_left_shift.Tensor",
         "aten::bitwise_left_shift.Tensor_Scalar",
         "aten::bitwise_left_shift.Scalar_Tensor",
+        "_operator::__lshift__",
     ),
     traceable=True,
 )
@@ -1329,6 +1330,7 @@ def aten_bitwise_or(self: TInt, other: TInt) -> TInt:
         "aten::bitwise_right_shift.Tensor",
         "aten::bitwise_right_shift.Tensor_Scalar",
         "aten::bitwise_right_shift.Scalar_Tensor",
+        "_operator::__rshift__",
     )
 )
 def aten_bitwise_right_shift_int16(self: INT16, other: INT16) -> INT16:
@@ -1358,6 +1360,7 @@ def aten_bitwise_right_shift_int16(self: INT16, other: INT16) -> INT16:
         "aten::bitwise_right_shift.Tensor",
         "aten::bitwise_right_shift.Tensor_Scalar",
         "aten::bitwise_right_shift.Scalar_Tensor",
+        "_operator::__rshift__",
     )
 )
 def aten_bitwise_right_shift_int32(self: INT32, other: INT32) -> INT32:
@@ -1387,6 +1390,7 @@ def aten_bitwise_right_shift_int32(self: INT32, other: INT32) -> INT32:
         "aten::bitwise_right_shift.Tensor",
         "aten::bitwise_right_shift.Tensor_Scalar",
         "aten::bitwise_right_shift.Scalar_Tensor",
+        "_operator::__rshift__",
     )
 )
 def aten_bitwise_right_shift_int64(self: INT64, other: INT64) -> INT64:
@@ -1419,6 +1423,7 @@ def aten_bitwise_right_shift_int64(self: INT64, other: INT64) -> INT64:
         "aten::bitwise_right_shift.Tensor",
         "aten::bitwise_right_shift.Tensor_Scalar",
         "aten::bitwise_right_shift.Scalar_Tensor",
+        "_operator::__rshift__",
     )
 )
 def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8:
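Note on the "_operator::__lshift__" / "_operator::__rshift__" entries added above: Python lowers `x << y` to `operator.__lshift__(x, y)`, and it is that operator-module dunder (not an aten overload) that shows up as the call target in traced graphs, so listing it in the same registrations lets the typed aten_bitwise_*_shift_* implementations serve both spellings. A minimal sketch of the identity being relied on (plain Python, not torch_lib code):

import operator

# `x << y` and `x >> y` desugar to the operator-module dunders; these are the
# callables that surface as the "_operator::__lshift__" / "_operator::__rshift__"
# targets named in the registrations above.
assert operator.__lshift__(1, 4) == (1 << 4) == 16
assert operator.__rshift__(16, 4) == (16 >> 4) == 1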
@@ -3606,30 +3611,35 @@ def aten_from_file(
 
 @torch_op("aten::full", trace_only=True)
 def aten_full(
-    size: INT64,
-    fill_value: FLOAT,
+    size: Union[INT64, INT32],
+    fill_value: TensorType,
     dtype: int = FLOAT.dtype,
     layout: str = "",
     device: str = "",
     pin_memory: bool = False,
-):
+) -> TensorType:
     """full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
 
-    size = op.Cast(size, to=INT64.dtype)
     if dtype != -1:
         fill_value = op.Cast(fill_value, to=dtype)
+    if isinstance(size, list) and size == []:
+        # TODO(justinchuby): Handle empty list better than using isinstance
+        # size can be empty, meaning a scalar
+        return fill_value
+
+    size = op.Cast(size, to=INT64.dtype)
     return op.Expand(fill_value, size)
 
 
 @torch_op("aten::full_like", trace_only=True)
 def aten_full_like(
-    self: TTensor,
-    fill_value: TTensor,
+    self: TensorType,
+    fill_value: TensorType,
     dtype: int = -1,
     layout: str = "",
     device: str = "",
     pin_memory: bool = False,
-) -> TTensor:
+) -> TensorType:
     """full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
 
     if dtype == -1:
@@ -4715,11 +4725,17 @@ def aten_linear_backward(
 
 @torch_op("aten::linspace", trace_only=True)
 def aten_linspace(
-    start: TFloat, end: TFloat, steps: int, dtype: int = FLOAT.dtype
+    start: TFloat,
+    end: TFloat,
+    steps: int,
+    dtype: int = FLOAT.dtype,
+    layout: str = "",
+    device: str = "",
+    pin_memory: bool = False,
 ) -> TensorType:
     """linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
 
-    if dtype == -1:
+    if dtype == -1 or dtype is None:
         dtype = FLOAT.dtype
 
     # Reference: https://github.com/pytorch/pytorch/blob/b35ca2cb941b5ba90858322810ca85c31e4541fd/torch/_refs/__init__.py#L4896
@@ -4743,14 +4759,14 @@
     )
 
 
-@torch_op("aten::log")
+@torch_op("aten::log", traceable=True)
 def aten_log(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
     """log(Tensor self) -> Tensor"""
 
     return op.Log(self)
 
 
-@torch_op("aten::log10")
+@torch_op("aten::log10", traceable=True)
 def aten_log10(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
     """log10(Tensor self) -> Tensor"""
 
@@ -4764,21 +4780,21 @@ def aten_log1p(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
     return op.Log(op.Add(self, 1.0))
 
 
-@torch_op("aten::log2")
+@torch_op("aten::log2", traceable=True)
 def aten_log2(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
     """log2(Tensor self) -> Tensor"""
 
     return op.Div(op.Log(self), op.CastLike(op.Log(2.0), self))
 
 
-@torch_op("aten::logaddexp")
+@torch_op("aten::logaddexp", traceable=True)
 def aten_logaddexp(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
     """logaddexp(Tensor self, Tensor other) -> Tensor"""
 
     return op.Log(op.Add(op.Exp(self), op.Exp(other)))
 
 
-@torch_op("aten::logaddexp2")
+@torch_op("aten::logaddexp2", traceable=True)
 def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
     """logaddexp2(Tensor self, Tensor other) -> Tensor"""
     two = op.CastLike(2.0, self)
@@ -4811,7 +4827,7 @@ def aten_logcumsumexp(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
     return result
 
 
-@torch_op("aten::logdet")
+@torch_op("aten::logdet", traceable=True)
 def aten_logdet(self: TFloat) -> TFloat:
     """logdet(Tensor self) -> Tensor"""
 
@@ -4824,7 +4840,8 @@
         "aten::bitwise_and.Tensor",
         "aten::bitwise_and.Scalar",
         "aten::bitwise_and.Scalar_Tensor",
-    )
+    ),
+    traceable=True,
 )
 def aten_logical_and(self: BOOL, other: BOOL) -> BOOL:
     """logical_and(Tensor self, Tensor other) -> Tensor"""
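For the empty-`size` branch added to aten_full above: in eager PyTorch, `torch.full` with an empty size produces a 0-d (scalar) tensor holding fill_value, so returning the already-cast fill_value directly, instead of emitting Expand with an empty shape, matches eager semantics. A quick eager-mode check of this assumption:

import torch

# An empty size means a scalar: torch.full([], v) yields a 0-d tensor, which
# is the case the new `isinstance(size, list) and size == []` branch
# short-circuits by returning fill_value directly.
t = torch.full([], 3.0)
assert t.ndim == 0 and t.item() == 3.0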
@@ -4832,7 +4849,7 @@ def aten_logical_and(self: BOOL, other: BOOL) -> BOOL:
     return op.And(self, other)
 
 
-@torch_op(("aten::logical_not", "aten::bitwise_not"))
+@torch_op(("aten::logical_not", "aten::bitwise_not"), traceable=True)
 def aten_logical_not(self: BOOL) -> BOOL:
     """logical_not(Tensor self) -> Tensor"""
 
@@ -4863,7 +4880,8 @@
         "aten::bitwise_xor.Tensor",
         "aten::bitwise_xor.Scalar",
         "aten::bitwise_xor.Scalar_Tensor",
-    )
+    ),
+    traceable=True,
 )
 def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL:
     """logical_xor(Tensor self, Tensor other) -> Tensor"""
@@ -4912,12 +4930,6 @@ def aten_logsumexp(self: TFloat, dim: INT64, keepdim: int = False) -> TFloat:
     return result
 
 
-def aten_lshift(self: TensorType, other: TensorType) -> TensorType:
-    """__lshift__.Tensor(Tensor self, Tensor other) -> Tensor"""
-
-    raise NotImplementedError()
-
-
 def aten_lstm_cell(
     input: TensorType,
     hx: Sequence[TensorType],
@@ -6226,7 +6238,7 @@ def aten_new_empty_strided(
 def aten_new_full(
     self: TTensor,
     size: INT64,
-    fill_value: TTensor,
+    fill_value: TensorType,
     dtype: int = -1,
     layout: str = "",
     device: str = "",
@@ -7308,12 +7320,6 @@ def aten_rrelu(
     raise NotImplementedError()
 
 
-def aten_rshift(self: TensorType, other: TensorType) -> TensorType:
-    """__rshift__.Tensor(Tensor self, Tensor other) -> Tensor"""
-
-    raise NotImplementedError()
-
-
 @torch_op("aten::rsqrt")
 def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
     """rsqrt(Tensor self) -> Tensor"""
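The deleted aten_lshift / aten_rshift stubs only raised NotImplementedError; their `__lshift__.Tensor` / `__rshift__.Tensor` surface appears to be covered by the "_operator::__lshift__" / "_operator::__rshift__" entries registered on the typed shift overloads earlier in this patch, since PyTorch defines `<<` and `>>` on integer tensors as the bitwise shifts. A small eager-mode check of that equivalence (illustrative only, not torch_lib code):

import torch

# `<<` / `>>` on integer tensors dispatch to bitwise_left_shift /
# bitwise_right_shift, so the typed aten_bitwise_*_shift_* overloads now
# handle what the deleted stubs declared.
x = torch.tensor([1, 2, 4], dtype=torch.int32)
assert torch.equal(x << 1, torch.bitwise_left_shift(x, 1))
assert torch.equal(x >> 1, torch.bitwise_right_shift(x, 1))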