Skip to content

Commit

Permalink
[torchlib] Make binary comparison ops and more traceable (#1957)
Browse files · Browse the repository at this point in the history
ge, gt, le, lt
  • Loading branch information
justinchuby authored Nov 19, 2024
1 parent 35b20fe commit 5c62178
Showing 1 changed file with 42 additions and 33 deletions.
75 changes: 42 additions & 33 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1585,14 +1585,14 @@ def aten_cdist(
raise NotImplementedError()


@torch_op("aten::ceil", traceable=True)
def aten_ceil(self: TFloat) -> TFloat:
    """ceil(Tensor self) -> Tensor"""

    # Maps directly onto the ONNX Ceil operator.
    ceiled = op.Ceil(self)
    return ceiled


@torch_op("math::ceil")
@torch_op("math::ceil", traceable=True)
def python_math_ceil(self: TFloat) -> TInt:
"""ceil(Tensor self) -> Tensor"""
ceil = op.Ceil(self)
Expand Down Expand Up @@ -1764,13 +1764,6 @@ def aten_combinations(
raise NotImplementedError()


@torch_op("aten::complex", private=True)
def _aten_complex(real: TFloat, imag: TFloat) -> TFloat:
"""Non-broadcasting complex constructor."""

return op.Concat(op.Unsqueeze(real, axes=[-1]), op.Unsqueeze(imag, axes=[-1]), axis=-1)


@torch_op("aten::complex", trace_only=True)
def aten_complex(real: TFloat, imag: TFloat) -> TFloat:
"""complex(Tensor real, Tensor imag) -> Tensor"""
Expand All @@ -1780,7 +1773,7 @@ def aten_complex(real: TFloat, imag: TFloat) -> TFloat:
real = op.Expand(real, broadcasted_shape)
imag = op.Expand(imag, broadcasted_shape)

return _aten_complex(real, imag)
return op.Concat(op.Unsqueeze(real, axes=[-1]), op.Unsqueeze(imag, axes=[-1]), axis=-1)


@torch_op("aten::conj", trace_only=True)
Expand All @@ -1790,7 +1783,6 @@ def aten_conj(self: TTensor) -> TTensor:
return op.Identity(self)


@torch_op("aten::conj", complex=True, private=True)
def _complex_conjugate(self: TFloat) -> TFloat:
zero = op.Constant(value_ints=[0])
one = op.Constant(value_ints=[1])
Expand All @@ -1809,8 +1801,6 @@ def _complex_conjugate(self: TFloat) -> TFloat:
def aten_conj_complex(self: TFloat) -> TFloat:
    """conj(Tensor(a) self) -> Tensor(a)"""

    # Delegates to the shared complex-conjugate helper defined above.
    conjugated = _complex_conjugate(self)
    return conjugated


Expand Down Expand Up @@ -3273,7 +3263,7 @@ def aten_empty_quantized(
raise NotImplementedError()


@torch_op("aten::empty_strided")
@torch_op("aten::empty_strided", traceable=True)
def aten_empty_strided(
size: INT64,
stride: INT64,
Expand All @@ -3290,14 +3280,14 @@ def aten_empty_strided(
return op.Expand(zero, size)


@torch_op(("aten::eq", "aten::eq.Tensor", "aten::eq.Scalar", "_operator::eq"), traceable=True)
def aten_eq(self: TTensor, other: TTensor) -> BOOL:
    """eq.Tensor(Tensor self, Tensor other) -> Tensor"""

    # Elementwise equality via the ONNX Equal operator.
    result = op.Equal(self, other)
    return result


@torch_op("aten::equal")
@torch_op("aten::equal", traceable=True)
def aten_equal(self: TTensor, other: TTensor) -> BOOL:
"""equal(Tensor self, Tensor other) -> bool"""

Expand Down Expand Up @@ -3759,7 +3749,8 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType:


@torch_op(
("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge")
("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge"),
traceable=True,
)
def aten_ge(self: TReal, other: TReal) -> BOOL:
"""ge.Tensor(Tensor self, Tensor other) -> Tensor"""
Expand All @@ -3768,7 +3759,8 @@ def aten_ge(self: TReal, other: TReal) -> BOOL:


@torch_op(
("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge")
("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge"),
traceable=True,
)
def aten_ge_bool(self: BOOL, other: BOOL) -> BOOL:
"""ge.Tensor(Tensor self, Tensor other) -> Tensor"""
Expand Down Expand Up @@ -3904,14 +3896,20 @@ def aten_gru_cell(
raise NotImplementedError()


@torch_op(
    ("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"),
    traceable=True,
)
def aten_gt(self: TReal, other: TReal) -> BOOL:
    """gt.Tensor(Tensor self, Tensor other) -> Tensor"""

    # Strict elementwise greater-than via the ONNX Greater operator.
    greater = op.Greater(self, other)
    return greater


@torch_op(("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"))
@torch_op(
("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"),
traceable=True,
)
def aten_gt_bool(self: BOOL, other: BOOL) -> BOOL:
"""gt.Tensor(Tensor self, Tensor other) -> Tensor"""
# self, other, self > other
Expand Down Expand Up @@ -3949,7 +3947,7 @@ def aten_hardshrink_backward(
raise NotImplementedError()


@torch_op("aten::heaviside")
@torch_op("aten::heaviside", traceable=True)
def aten_heaviside(self: TReal, values: TReal) -> TReal:
"""heaviside(Tensor self, Tensor values) -> Tensor"""

Expand Down Expand Up @@ -4695,14 +4693,20 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op(
    ("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"),
    traceable=True,
)
def aten_le(self: TReal, other: TReal) -> BOOL:
    """le.Tensor(Tensor self, Tensor other) -> Tensor"""

    # Elementwise less-than-or-equal via the ONNX LessOrEqual operator.
    le_result = op.LessOrEqual(self, other)
    return le_result


@torch_op(("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"))
@torch_op(
("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"),
traceable=True,
)
def aten_le_bool(self: BOOL, other: BOOL) -> BOOL:
"""le.Tensor(Tensor self, Tensor other) -> Tensor"""

Expand Down Expand Up @@ -5002,14 +5006,20 @@ def aten_lstm_mps_backward(
raise NotImplementedError()


@torch_op(
    ("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"),
    traceable=True,
)
def aten_lt(self: TReal, other: TReal) -> BOOL:
    """lt.Tensor(Tensor self, Tensor other) -> Tensor"""

    # Strict elementwise less-than via the ONNX Less operator.
    less = op.Less(self, other)
    return less


@torch_op(("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"))
@torch_op(
("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"),
traceable=True,
)
def aten_lt_bool(self: BOOL, other: BOOL) -> BOOL:
"""lt.Tensor(Tensor self, Tensor other) -> Tensor"""

Expand Down Expand Up @@ -5051,9 +5061,6 @@ def aten_mH(self: TRealOrUInt8) -> TRealOrUInt8:
def aten_mH_complex(self: TFloat) -> TFloat:
    """mH(Tensor(a) self) -> Tensor(a)

    Conjugate transpose of a complex tensor represented in the
    real/imaginary-pair layout.
    """

    # c is the last dimension being the real and imaginary parts;
    # swap the two matrix dimensions while keeping c in place.
    # (Fixes the misspelled local "trasposed".)
    transposed = op.Einsum(self, equation="...ijc->...jic")
    return _complex_conjugate(transposed)
Expand Down Expand Up @@ -6218,14 +6225,14 @@ def aten_native_norm(self: TensorType, p: float = 2.0) -> TensorType:
raise NotImplementedError()


@torch_op(("aten::ne", "aten::ne.Scalar", "aten::ne.Tensor", "_operator::ne"), traceable=True)
def aten_ne(self: TReal, other: TReal) -> BOOL:
    """ne.Tensor(Tensor self, Tensor other) -> Tensor"""

    # ONNX has no NotEqual operator, so negate Equal instead.
    equal = op.Equal(self, other)
    return op.Not(equal)


@torch_op(("aten::neg", "_operator::neg"))
@torch_op(("aten::neg", "_operator::neg"), traceable=True)
def aten_neg(self: TReal) -> TReal:
"""neg(Tensor self) -> Tensor"""

Expand Down Expand Up @@ -7067,7 +7074,7 @@ def aten_real(self: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::reciprocal")
@torch_op("aten::reciprocal", traceable=True)
def aten_reciprocal(self: TFloat) -> TFloat:
"""reciprocal(Tensor self) -> Tensor"""

Expand All @@ -7086,7 +7093,7 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType:
raise NotImplementedError()


@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"))
@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"), traceable=True)
def aten_remainder(self: TFloat, other: TFloat) -> TFloat:
"""remainder.Tensor(Tensor self, Tensor other) -> Tensor"""

Expand All @@ -7099,7 +7106,9 @@ def aten_remainder(self: TFloat, other: TFloat) -> TFloat:
return op.Sub(self, op.Mul(rounded_quotient, other))


@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar", "_operator::mod"))
@torch_op(
("aten::remainder.Tensor", "aten::remainder.Scalar", "_operator::mod"), traceable=True
)
def aten_remainder_int(self: TInt, other: TInt) -> TInt:
"""remainder.Tensor(Tensor self, Tensor other) -> Tensor"""

Expand Down

0 comments on commit 5c62178

Please sign in to comment.