diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index dfc0e7882..475458892 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2236,23 +2236,13 @@ def aten_cov( raise NotImplementedError() -@torch_op("aten::cross") +@torch_op(("aten::cross", "aten::linalg_cross")) def aten_cross(self: TTensor, other: TTensor, dim: int = -1) -> TTensor: """cross(Tensor self, Tensor other, int? dim=None) -> Tensor""" - zero = op.Constant(value_ints=[0]) - one = op.Constant(value_ints=[1]) - two = op.Constant(value_ints=[2]) - three = op.Constant(value_ints=[3]) - axes = op.Expand(dim, op.Constant(value_ints=[1])) - # Reference https://en.wikipedia.org/w/index.php?title=Cross_product&oldid=1143125073 - a1 = op.Slice(self, zero, one, axes) - a2 = op.Slice(self, one, two, axes) - a3 = op.Slice(self, two, three, axes) - b1 = op.Slice(other, zero, one, axes) - b2 = op.Slice(other, one, two, axes) - b3 = op.Slice(other, two, three, axes) + a1, a2, a3 = op.Split(self, axis=dim, num_outputs=3) + b1, b2, b3 = op.Split(other, axis=dim, num_outputs=3) # Broadcasting is implicitly supported by Mul c1 = op.Sub(op.Mul(a2, b3), op.Mul(a3, b2)) c2 = op.Sub(op.Mul(a3, b1), op.Mul(a1, b3)) @@ -3571,7 +3561,7 @@ def aten_fmin(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::fmod") +@torch_op(("aten::fmod.Tensor", "aten::fmod.Scalar")) def aten_fmod(self: TRealOrUInt8, other: TRealOrUInt8) -> TRealOrUInt8: """fmod.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -4659,7 +4649,7 @@ def aten_le(self: TReal, other: TReal) -> BOOL: return op.LessOrEqual(self, other) -@torch_op(("aten::le.Tensor", "aten::less_equal.Tensor", "_operator::le")) +@torch_op(("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le")) def aten_le_bool(self: BOOL, other: BOOL) -> BOOL: """le.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -4672,10 +4662,17 @@ def aten_le_bool(self: BOOL, other: BOOL) -> BOOL: return op.Or(other, op.Not(self)) -def aten_lerp(self: TensorType, end: TensorType, weight: TensorType) -> TensorType: +@torch_op(("aten::lerp.Tensor", "aten::lerp.Scalar")) +def aten_lerp(self: TTensor, end: TTensor, weight: TTensor) -> TTensor: """lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor""" - raise NotImplementedError() + weight = op.CastLike(weight, self) + diff = op.Sub(end, self) + return op.Where( + op.Less(weight, 0.5), + op.Add(self, op.Mul(weight, diff)), + op.Sub(end, op.Mul(diff, op.Sub(1.0, weight))), + ) def aten_lgamma(self: TensorType) -> TensorType: @@ -5619,10 +5616,11 @@ def aten_multiply(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() +@torch_op("aten::mv") def aten_mv(self: TensorType, vec: TensorType) -> TensorType: """mv(Tensor self, Tensor vec) -> Tensor""" - raise NotImplementedError() + return op.MatMul(self, vec) def aten_mvlgamma(self: TensorType, p: int) -> TensorType: @@ -7011,7 +7009,7 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType: raise NotImplementedError() -@torch_op("aten::remainder") +@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar")) def aten_remainder(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16: """remainder.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -7024,7 +7022,7 @@ def aten_remainder(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrB return op.Sub(self, op.Mul(rounded_quotient, other)) -@torch_op("aten::remainder") +@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar")) def aten_remainder_int(self: TInt, other: TInt) -> TInt: """remainder.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -8533,10 +8531,11 @@ def aten_unsafe_chunk(self: TensorType, chunks: int, dim: int = 0) -> TensorType raise NotImplementedError() -def aten_unsafe_split(self: TensorType, split_size: INT64, dim: int = 0) -> TensorType: +@torch_op(("aten::unsafe_split", "aten::unsafe_split.Tensor")) +def aten_unsafe_split(self: TTensor, split_size: INT64, dim: int = 0) -> Sequence[TTensor]: """unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]""" - raise NotImplementedError() + return op.SplitToSequence(self, split_size, axis=dim) def aten_unsafe_split_with_sizes( diff --git a/onnxscript/function_libs/torch_lib/ops/linalg.py b/onnxscript/function_libs/torch_lib/ops/linalg.py index 0dd8eced4..ebc07b5d3 100644 --- a/onnxscript/function_libs/torch_lib/ops/linalg.py +++ b/onnxscript/function_libs/torch_lib/ops/linalg.py @@ -17,7 +17,7 @@ from onnxscript import BOOL, FLOAT, INT64 from onnxscript.function_libs.torch_lib.ops import common as common_ops from onnxscript.function_libs.torch_lib.registration import torch_op -from onnxscript.function_libs.torch_lib.tensor_typing import TFloat +from onnxscript.function_libs.torch_lib.tensor_typing import TFloat, TTensor from onnxscript.onnx_opset import opset18 as op from onnxscript.onnx_types import TensorType @@ -44,9 +44,10 @@ def aten_linalg_cond(self: TensorType, p: Optional[float] = None) -> TensorType: raise NotImplementedError() -def aten_linalg_cross(self: TensorType, other: TensorType, dim: int = -1) -> TensorType: +def aten_linalg_cross(self: TTensor, other: TTensor, dim: int = -1) -> TTensor: """linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor""" + # Same implementation as aten_cross raise NotImplementedError() diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index b4f3c5701..773c19f1d 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -900,6 +900,11 @@ def _where_input_wrangler( TorchLibOpInfo("log", core_ops.aten_log), TorchLibOpInfo("le", core_ops.aten_le), TorchLibOpInfo("le_bool", core_ops.aten_le_bool), + TorchLibOpInfo( + "lerp", + core_ops.aten_lerp, + tolerance={torch.float16: (2e-3, 2e-1)}, + ), TorchLibOpInfo("log10", core_ops.aten_log10), TorchLibOpInfo("log1p", core_ops.aten_log1p), TorchLibOpInfo( @@ -1020,6 +1025,11 @@ def _where_input_wrangler( TorchLibOpInfo("mT", core_ops.aten_mT_complex, complex=True), TorchLibOpInfo("mul", core_ops.aten_mul), TorchLibOpInfo("mul", core_ops.aten_mul_complex, complex=True), + TorchLibOpInfo( + "mv", + core_ops.aten_mv, + tolerance={torch.float16: (3e-2, 1e-2)}, + ), TorchLibOpInfo("narrow", core_ops.aten_narrow), TorchLibOpInfo("ops.aten.native_dropout", core_ops.aten_native_dropout), TorchLibOpInfo("ne", core_ops.aten_ne),