
[torchlib] Implement missing operators (set1) #1706

Merged · 12 commits · Jul 17, 2024
42 changes: 20 additions & 22 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -2236,23 +2236,13 @@
     raise NotImplementedError()


-@torch_op("aten::cross")
+@torch_op(("aten::cross", "aten::linalg_cross"))
 def aten_cross(self: TTensor, other: TTensor, dim: int = -1) -> TTensor:
     """cross(Tensor self, Tensor other, int? dim=None) -> Tensor"""

-    zero = op.Constant(value_ints=[0])
-    one = op.Constant(value_ints=[1])
-    two = op.Constant(value_ints=[2])
-    three = op.Constant(value_ints=[3])
-    axes = op.Expand(dim, op.Constant(value_ints=[1]))
-
     # Reference https://en.wikipedia.org/w/index.php?title=Cross_product&oldid=1143125073
-    a1 = op.Slice(self, zero, one, axes)
-    a2 = op.Slice(self, one, two, axes)
-    a3 = op.Slice(self, two, three, axes)
-    b1 = op.Slice(other, zero, one, axes)
-    b2 = op.Slice(other, one, two, axes)
-    b3 = op.Slice(other, two, three, axes)
+    a1, a2, a3 = op.Split(self, axis=dim, num_outputs=3)
+    b1, b2, b3 = op.Split(other, axis=dim, num_outputs=3)

     # Broadcasting is implicitly supported by Mul
     c1 = op.Sub(op.Mul(a2, b3), op.Mul(a3, b2))
     c2 = op.Sub(op.Mul(a3, b1), op.Mul(a1, b3))
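The two `Split` calls replace the previous six `Slice` calls, and the arithmetic is the textbook cross product: c1 = a2*b3 - a3*b2, c2 = a3*b1 - a1*b3, c3 = a1*b2 - a2*b1. A minimal PyTorch sketch of the same computation, for reference only (shapes and names are illustrative, not from this change):

```python
import torch

a, b = torch.randn(4, 3), torch.randn(4, 3)
a1, a2, a3 = a.split(1, dim=-1)  # mirrors op.Split(..., num_outputs=3)
b1, b2, b3 = b.split(1, dim=-1)
c = torch.cat(
    (a2 * b3 - a3 * b2,   # c1
     a3 * b1 - a1 * b3,   # c2
     a1 * b2 - a2 * b1),  # c3
    dim=-1,
)
assert torch.allclose(c, torch.linalg.cross(a, b, dim=-1))
```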
@@ -3571,7 +3561,7 @@
     raise NotImplementedError()


-@torch_op("aten::fmod")
+@torch_op(("aten::fmod.Tensor", "aten::fmod.Scalar"))
 def aten_fmod(self: TRealOrUInt8, other: TRealOrUInt8) -> TRealOrUInt8:
     """fmod.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -4659,7 +4649,7 @@
     return op.LessOrEqual(self, other)


-@torch_op(("aten::le.Tensor", "aten::less_equal.Tensor", "_operator::le"))
+@torch_op(("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"))
 def aten_le_bool(self: BOOL, other: BOOL) -> BOOL:
     """le.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -4672,10 +4662,16 @@
     return op.Or(other, op.Not(self))


-def aten_lerp(self: TensorType, end: TensorType, weight: TensorType) -> TensorType:
+@torch_op(("aten::lerp.Tensor", "aten::lerp.Scalar"))
+def aten_lerp(self: TReal, end: TReal, weight: TReal) -> TReal:
"""lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor"""

raise NotImplementedError()
diff = op.Sub(end, self)
return op.Where(

+        op.Less(weight, 0.5),
+        op.Add(self, op.Mul(weight, diff)),
+        op.Sub(end, op.Mul(diff, op.Sub(1.0, weight)))
+    )


 def aten_lgamma(self: TensorType) -> TensorType:
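The new `aten_lerp` selects between two algebraically equivalent forms per element: `self + weight * diff` when `weight < 0.5`, otherwise `end - diff * (1 - weight)`. Both equal `self + weight * (end - self)`; branching keeps the endpoints exact at `weight == 0` and `weight == 1`. A NumPy sketch of the same selection (illustrative only, not the torch_lib code):

```python
import numpy as np

def lerp(a, b, w):
    # Same branch selection as the Where above; both branches are
    # algebraically a + w * (b - a).
    diff = b - a
    return np.where(w < 0.5, a + w * diff, b - diff * (1.0 - w))

a, b = np.float32(1.0), np.float32(3.0)
assert lerp(a, b, np.float32(0.0)) == a  # exact at weight == 0
assert lerp(a, b, np.float32(1.0)) == b  # exact at weight == 1
```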
@@ -5619,10 +5615,11 @@
     raise NotImplementedError()


+@torch_op("aten::mv")
 def aten_mv(self: TensorType, vec: TensorType) -> TensorType:
     """mv(Tensor self, Tensor vec) -> Tensor"""

-    raise NotImplementedError()
+    return op.MatMul(self, vec)



 def aten_mvlgamma(self: TensorType, p: int) -> TensorType:
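`aten_mv` lowers to a single `MatMul`. ONNX `MatMul` follows `numpy.matmul` semantics, so a rank-2 matrix times a rank-1 vector contracts the matrix's last axis with the vector and yields a rank-1 result, which is exactly what `torch.mv` returns. A one-line check (illustrative):

```python
import numpy as np

mat = np.ones((4, 3), dtype=np.float32)  # (n, m) matrix
vec = np.ones(3, dtype=np.float32)       # (m,) vector
assert np.matmul(mat, vec).shape == (4,)  # rank-1 result, as torch.mv gives
```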
@@ -7011,7 +7008,7 @@
     raise NotImplementedError()


-@torch_op("aten::remainder")
+@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"))
 def aten_remainder(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
     """remainder.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -7024,7 +7021,7 @@
     return op.Sub(self, op.Mul(rounded_quotient, other))


-@torch_op("aten::remainder")
+@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"))
 def aten_remainder_int(self: TInt, other: TInt) -> TInt:
     """remainder.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -8533,10 +8530,11 @@
     raise NotImplementedError()


-def aten_unsafe_split(self: TensorType, split_size: INT64, dim: int = 0) -> TensorType:
+@torch_op(("aten::unsafe_split", "aten::unsafe_split.Tensor"))
+def aten_unsafe_split(self: TTensor, split_size: INT64, dim: int = 0) -> Sequence[TTensor]:
     """unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]"""

-    raise NotImplementedError()
+    return op.SplitToSequence(self, split_size, axis=dim)
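`SplitToSequence` with a scalar `split` divides the axis into chunks of `split_size`, with a smaller final chunk when the dimension is not evenly divisible. That is the behavior `torch.unsafe_split` shares with `torch.split` (per the PyTorch docs, the unsafe variants only skip shared-storage bookkeeping). For example (illustrative):

```python
import torch

parts = torch.unsafe_split(torch.arange(10), 4)  # split_size=4, dim=0
assert [p.tolist() for p in parts] == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```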



 def aten_unsafe_split_with_sizes(
6 changes: 3 additions & 3 deletions onnxscript/function_libs/torch_lib/ops/linalg.py
@@ -17,7 +17,7 @@
 from onnxscript import BOOL, FLOAT, INT64
 from onnxscript.function_libs.torch_lib.ops import common as common_ops
 from onnxscript.function_libs.torch_lib.registration import torch_op
-from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
+from onnxscript.function_libs.torch_lib.tensor_typing import TFloat, TTensor
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType

@@ -44,9 +44,9 @@ def aten_linalg_cond(self: TensorType, p: Optional[float] = None) -> TensorType:
     raise NotImplementedError()


-def aten_linalg_cross(self: TensorType, other: TensorType, dim: int = -1) -> TensorType:
-    """linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor"""
+def aten_linalg_cross(self: TTensor, other: TTensor, dim: int = -1) -> TTensor:

+    # Same implementation as aten_cross
     raise NotImplementedError()


5 changes: 5 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -871,6 +871,7 @@ def _where_input_wrangler(
     TorchLibOpInfo("isneginf", core_ops.aten_isneginf),
     TorchLibOpInfo("isposinf", core_ops.aten_isposinf),
     TorchLibOpInfo("lift_fresh_copy", core_ops.aten_lift_fresh_copy),
+    TorchLibOpInfo("linalg.cross", linalg_ops.aten_linalg_cross),
     TorchLibOpInfo("linalg.det", linalg_ops.aten_linalg_det),
     TorchLibOpInfo(
         "linalg.vector_norm",
@@ -900,6 +901,7 @@ def _where_input_wrangler(
     TorchLibOpInfo("log", core_ops.aten_log),
     TorchLibOpInfo("le", core_ops.aten_le),
     TorchLibOpInfo("le_bool", core_ops.aten_le_bool),
+    TorchLibOpInfo("lerp", core_ops.aten_lerp),
     TorchLibOpInfo("log10", core_ops.aten_log10),
     TorchLibOpInfo("log1p", core_ops.aten_log1p),
     TorchLibOpInfo(
@@ -1020,6 +1022,7 @@ def _where_input_wrangler(
     TorchLibOpInfo("mT", core_ops.aten_mT_complex, complex=True),
     TorchLibOpInfo("mul", core_ops.aten_mul),
     TorchLibOpInfo("mul", core_ops.aten_mul_complex, complex=True),
+    TorchLibOpInfo("mv", core_ops.aten_mv),
     TorchLibOpInfo("narrow", core_ops.aten_narrow),
     TorchLibOpInfo("ops.aten.native_dropout", core_ops.aten_native_dropout),
     TorchLibOpInfo("ne", core_ops.aten_ne),
@@ -2390,6 +2393,7 @@ def _where_input_wrangler(
         "imag",
         "isfinite",
         "le",
+        "lerp",
"lgamma",
"log",
"log10",
@@ -2399,6 +2403,7 @@ def _where_input_wrangler(
         "maximum",
         "minimum",
         "mul",
+        "mv",
"ne",
"neg",
"nextafter",