Add Op(built-in function ops) | feat(torchlib) (#1135)
Fix microsoft/onnx-converters-private#190

With pytorch/pytorch#112758, this PR moves the
built-in function ops mapping into torchlib.
NOTE: the module name of `operator.add` is `_operator`.
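For context: CPython implements the `operator` module in C as `_operator` (`operator.py` re-exports it), so built-in operator functions report `_operator` as their module name, which is why the registrations in this diff use `_operator::*` op names. A quick check in plain Python (illustrative, not part of this diff):

    import operator
    import _operator

    # The public `operator` module re-exports the C implementation,
    # so built-in operator functions report `_operator` as their module.
    print(operator.add.__module__)        # '_operator'
    print(operator.add is _operator.add)  # True
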
titaiwangms authored Nov 7, 2023
1 parent b0147d8 commit 662af2a
Showing 1 changed file with 33 additions and 12 deletions.
45 changes: 33 additions & 12 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -124,7 +124,7 @@ def aten__softmax(
     return aten_softmax_no_dtype(self, dim)
 
 
-@torch_op("aten::abs")
+@torch_op(("aten::abs", "_operator::abs"))
 def aten_abs(self: TRealOrUInt8) -> TRealOrUInt8:
     """abs(Tensor self) -> Tensor"""
 
@@ -158,7 +158,7 @@ def aten_acosh(self: TFloat) -> TFloat:
     return op.Acosh(self)
 
 
-@torch_op(("aten::add", "aten::add.Tensor"))
+@torch_op(("aten::add", "aten::add.Tensor", "_operator::add"))
 def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
     # TODO(microsoft/onnxruntime#15977): Improve fp16 precision
@@ -1163,6 +1163,7 @@ def aten_binomial(
         "aten::bitwise_and.Tensor",
         "aten::bitwise_and.Scalar",
         "aten::bitwise_and.Scalar_Tensor",
+        "_operator::and_",
     )
 )
 def aten_bitwise_and(self: TInt, other: TInt) -> TInt:
@@ -1234,6 +1235,7 @@ def aten_bitwise_not(self: TInt) -> TInt:
         "aten::bitwise_or.Tensor",
         "aten::bitwise_or.Scalar",
         "aten::bitwise_or.Scalar_Tensor",
+        "_operator::or_",
     )
 )
 def aten_bitwise_or(self: TInt, other: TInt) -> TInt:
@@ -1443,6 +1445,13 @@ def aten_ceil(self: TFloat) -> TFloat:
     return op.Ceil(self)
 
 
+@torch_op("math::ceil")
+def python_math_ceil(self: TFloat) -> TInt:
+    """ceil(Tensor self) -> Tensor"""
+    ceil = op.Ceil(self)
+    return op.Cast(ceil, to=INT64.dtype)
+
+
 def aten_chain_matmul(matrices: Sequence[TensorType]) -> TensorType:
     """chain_matmul(Tensor[] matrices) -> Tensor"""
 
@@ -2617,6 +2626,7 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2.0) -> TensorType
         "aten::div.Scalar_mode",
         "aten::divide",
         "aten::true_divide",
+        "_operator::truediv",
     )
 )
 def aten_div(self: TFloat, other: TFloat) -> TFloat:
@@ -3372,7 +3382,14 @@ def aten_floor(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
     return op.Floor(self)
 
 
-@torch_op("aten::floor_divide")
+@torch_op("math::floor")
+def python_math_floor(self: TFloatOrBFloat16) -> TInt:
+    """floor(Tensor self) -> Tensor"""
+    floor = op.Floor(self)
+    return op.Cast(floor, to=INT64.dtype)
+
+
+@torch_op(("aten::floor_divide", "_operator::floordiv"))
 def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat:
     """floor_divide(Tensor self, Tensor other) -> Tensor"""
 
@@ -3514,7 +3531,9 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op(("aten::ge", "aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal"))
+@torch_op(
+    ("aten::ge", "aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal", "_operator::ge")
+)
 def aten_ge(self: TReal, other: TReal) -> BOOL:
     """ge.Tensor(Tensor self, Tensor other) -> Tensor"""
 
@@ -3670,7 +3689,7 @@ def aten_gru_cell(
     raise NotImplementedError()
 
 
-@torch_op(("aten::gt", "aten::gt.Scalar", "aten::greater"))
+@torch_op(("aten::gt", "aten::gt.Scalar", "aten::greater", "_operator::gt"))
 def aten_gt(self: TReal, other: TReal) -> BOOL:
     """gt.Tensor(Tensor self, Tensor other) -> Tensor"""
 
@@ -4382,7 +4401,7 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op(("aten::le", "aten::le.Tensor"))
+@torch_op(("aten::le", "aten::le.Tensor", "_operator::le"))
 def aten_le(self: TReal, other: TReal) -> BOOL:
     """le.Tensor(Tensor self, Tensor other) -> Tensor"""
 
@@ -4686,7 +4705,7 @@ def aten_lstm_mps_backward(
     raise NotImplementedError()
 
 
-@torch_op(("aten::lt", "aten::lt.Scalar", "aten::less"))
+@torch_op(("aten::lt", "aten::lt.Scalar", "aten::less", "_operator::lt"))
 def aten_lt(self: TReal, other: TReal) -> BOOL:
     """lt.Tensor(Tensor self, Tensor other) -> Tensor"""
 
@@ -5288,7 +5307,7 @@ def aten_msort(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op(("aten::mul", "aten::mul.Tensor"))
+@torch_op(("aten::mul", "aten::mul.Tensor", "_operator::mul"))
 def aten_mul(self: TReal, other: TReal) -> TReal:
     """mul.Tensor(Tensor self, Tensor other) -> Tensor"""
     # FIXME(titaiwang): get rid of this when we have type_promotion
@@ -5739,14 +5758,14 @@ def aten_native_norm(self: TensorType, p: float = 2.0) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op(("aten::ne", "aten::ne.Scalar", "aten::ne.Tensor"))
+@torch_op(("aten::ne", "aten::ne.Scalar", "aten::ne.Tensor", "_operator::ne"))
 def aten_ne(self: TReal, other: TReal) -> BOOL:
     """ne.Tensor(Tensor self, Tensor other) -> Tensor"""
 
     return op.Not(op.Equal(self, other))
 
 
-@torch_op("aten::neg")
+@torch_op(("aten::neg", "_operator::neg"))
 def aten_neg(self: TReal) -> TReal:
     """neg(Tensor self) -> Tensor"""
 
@@ -6126,7 +6145,9 @@ def aten_positive(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op(("aten::pow", "aten::pow.Tensor_Tensor", "aten::pow.Tensor_Scalar"))
+@torch_op(
+    ("aten::pow", "aten::pow.Tensor_Tensor", "aten::pow.Tensor_Scalar", "_operator::pow")
+)
 def aten_pow(self: TReal, exponent: TTensor) -> TReal:
     """pow(Tensor self, Tensor exponent) -> Tensor"""
 
@@ -7406,7 +7427,7 @@ def aten_stft(
     return result
 
 
-@torch_op(("aten::sub", "aten::sub.Tensor"))
+@torch_op(("aten::sub", "aten::sub.Tensor", "_operator::sub"))
 def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     """sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
     alpha = op.CastLike(alpha, other)
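
A minimal sketch of what registering one implementation under several op names amounts to, assuming a simplified registry; the real `torch_op` decorator in torchlib is more involved, and `torch_op_sketch` here is a hypothetical name:

    from typing import Callable, Dict, Sequence, Union

    _REGISTRY: Dict[str, Callable] = {}

    def torch_op_sketch(names: Union[str, Sequence[str]]):
        """Register one implementation under one or more 'namespace::op' names."""
        if isinstance(names, str):
            names = (names,)
        def decorator(func: Callable) -> Callable:
            for name in names:
                _REGISTRY[name] = func
            return func
        return decorator

    @torch_op_sketch(("aten::add.Tensor", "_operator::add"))
    def add_impl(a, b):
        return a + b

    # Both the aten overload and the built-in operator resolve to the same function.
    assert _REGISTRY["_operator::add"] is add_impl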

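A side note on the new `math::ceil` and `math::floor` functions in this diff: unlike `aten::ceil` and `aten::floor`, Python's `math.ceil` and `math.floor` return an `int`, which is why the ONNX translations cast their results to INT64. Illustrative plain Python (not part of this diff):

    import math

    # math.ceil/math.floor return Python ints, not floats, matching
    # the Cast-to-INT64 in python_math_ceil/python_math_floor.
    print(math.ceil(2.3), type(math.ceil(2.3)))    # 3 <class 'int'>
    print(math.floor(2.3), type(math.floor(2.3)))  # 2 <class 'int'>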