[torchlib] Fix registrations 3/n #1740

Merged · 15 commits · Jul 22, 2024
89 changes: 66 additions & 23 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -164,7 +164,7 @@ def aten_acosh(self: TFloat) -> TFloat:
     return op.Acosh(self)


-@torch_op(("aten::add", "aten::add.Tensor", "_operator::add"))
+@torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), traceable=True)
 def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
     # TODO(microsoft/onnxruntime#15977): Improve fp16 precision
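Note: the body of aten_add is truncated in this hunk; the `alpha` argument scales `other` before the Add. A minimal sketch of that lowering, assuming the `op` opset alias and type aliases core.py already imports:

    def add_sketch(self, other, alpha=1.0):
        # CastLike matches alpha's dtype to other's before scaling
        alpha = op.CastLike(alpha, other)
        return op.Add(self, op.Mul(other, alpha))  # self + alpha * other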
@@ -173,7 +173,9 @@ def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     return op.Add(self, other)


-@torch_op(("aten::add", "aten::add.Tensor", "_operator::add"), trace_only=True, complex=True)
+@torch_op(
+    ("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True, complex=True
+)
 def aten_add_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""

@@ -233,7 +235,7 @@ def aten_addcmul(
     return op.Add(self, op.Mul(op.Mul(value, tensor1), tensor2))


-@torch_op("aten::addmm")
+@torch_op("aten::addmm", traceable=True)
 def aten_addmm(
     self: TReal, mat1: TReal, mat2: TReal, beta: float = 1.0, alpha: float = 1.0
 ) -> TReal:
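For reference, addmm(self, mat1, mat2, beta, alpha) computes beta*self + alpha*(mat1 @ mat2), which lines up with a single ONNX Gemm node. A sketch, not the verbatim body (which the diff truncates):

    def addmm_sketch(self, mat1, mat2, beta=1.0, alpha=1.0):
        # Gemm fuses the matmul, the scaling, and the bias add into one node
        return op.Gemm(mat1, mat2, self, alpha=alpha, beta=beta)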
@@ -1140,7 +1142,7 @@ def aten_batch_norm_update_stats(
     raise NotImplementedError()


-@torch_op("aten::bernoulli")
+@torch_op("aten::bernoulli", traceable=True)
 def aten_bernoulli(self: TFloat) -> TFloat:
     """Proximal implementation of aten::bernoulli.default

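aten::bernoulli treats each element of the input as a success probability. One standard lowering, sketched under the assumption that the truncated body compares a uniform sample against those probabilities:

    def bernoulli_sketch(self):
        u = op.RandomUniformLike(self, low=0.0, high=1.0)  # U(0, 1) per element
        return op.CastLike(op.Less(u, self), self)         # 1.0 where u < p, else 0.0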
@@ -1212,7 +1214,8 @@ def aten_binomial(
         "aten::bitwise_and.Scalar",
         "aten::bitwise_and.Scalar_Tensor",
         "_operator::and_",
-    )
+    ),
+    traceable=True,
 )
 def aten_bitwise_and(self: TInt, other: TInt) -> TInt:
     """bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor"""
@@ -1226,7 +1229,8 @@ def aten_bitwise_and(self: TInt, other: TInt) -> TInt:
         "aten::bitwise_left_shift.Tensor",
         "aten::bitwise_left_shift.Tensor_Scalar",
         "aten::bitwise_left_shift.Scalar_Tensor",
-    )
+    ),
+    traceable=True,
 )
 def aten_bitwise_left_shift_int16(self: INT16, other: INT16) -> INT16:
     """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
@@ -1244,7 +1248,8 @@
         "aten::bitwise_left_shift.Tensor",
         "aten::bitwise_left_shift.Tensor_Scalar",
         "aten::bitwise_left_shift.Scalar_Tensor",
-    )
+    ),
+    traceable=True,
 )
 def aten_bitwise_left_shift_int32(self: INT32, other: INT32) -> INT32:
     """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
@@ -1262,7 +1267,8 @@
         "aten::bitwise_left_shift.Tensor",
         "aten::bitwise_left_shift.Tensor_Scalar",
         "aten::bitwise_left_shift.Scalar_Tensor",
-    )
+    ),
+    traceable=True,
 )
 def aten_bitwise_left_shift_int64(self: INT64, other: INT64) -> INT64:
     """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
@@ -1280,7 +1286,8 @@ def aten_bitwise_left_shift_int64(self: INT64, other: INT64) -> INT64:
         "aten::bitwise_left_shift.Tensor",
         "aten::bitwise_left_shift.Tensor_Scalar",
         "aten::bitwise_left_shift.Scalar_Tensor",
-    )
+    ),
+    traceable=True,
 )
 def aten_bitwise_left_shift_int8(self: INT8, other: INT8) -> INT8:
     """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
@@ -1293,7 +1300,7 @@ def aten_bitwise_left_shift_int8(self: INT8, other: INT8) -> INT8:
     return op.Cast(result, to=INT8.dtype)


-@torch_op("aten::bitwise_not")
+@torch_op("aten::bitwise_not", traceable=True)
 def aten_bitwise_not(self: TInt) -> TInt:
     """bitwise_not(Tensor self) -> Tensor"""
     # logical_not implements the BOOL variant
@@ -1307,7 +1314,8 @@ def aten_bitwise_not(self: TInt) -> TInt:
         "aten::bitwise_or.Scalar",
         "aten::bitwise_or.Scalar_Tensor",
         "_operator::or_",
-    )
+    ),
+    traceable=True,
 )
 def aten_bitwise_or(self: TInt, other: TInt) -> TInt:
     """bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor"""
@@ -1440,7 +1448,8 @@ def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8:
         "aten::bitwise_xor.Tensor",
         "aten::bitwise_xor.Scalar",
         "aten::bitwise_xor.Scalar_Tensor",
-    )
+    ),
+    traceable=True,
 )
 def aten_bitwise_xor(self: TInt, other: TInt) -> TInt:
     """bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor"""
@@ -3480,7 +3489,7 @@ def aten_feature_dropout(input: TensorType, p: float, train: bool) -> TensorType
     raise NotImplementedError()


-@torch_op("aten::fill.Tensor")
+@torch_op(("aten::fill.Tensor", "aten::fill.Scalar"), traceable=True)
 def aten_fill(self: TTensor, value: TTensor) -> TTensor:
     """fill.Tensor(Tensor self, Tensor value) -> Tensor"""

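fill broadcasts a single value across self's shape; the usual lowering expands the value tensor with the input's runtime shape. A sketch of the truncated body:

    def fill_sketch(self, value):
        return op.Expand(value, op.Shape(self))  # broadcast value to self's shape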
@@ -4834,9 +4843,11 @@ def aten_logical_not(self: BOOL) -> BOOL:
         "aten::bitwise_or.Tensor",
         "aten::bitwise_or.Scalar",
         "aten::bitwise_or.Scalar_Tensor",
-        "aten::add",
         "aten::add.Tensor",
-    )
+        "aten::add.Scalar",
+        "_operator::add",
+    ),
+    traceable=True,
 )
 def aten_logical_or(self: BOOL, other: BOOL) -> BOOL:
     """logical_or(Tensor self, Tensor other) -> Tensor"""
@@ -5544,14 +5555,20 @@ def aten_msort(self: TensorType) -> TensorType:
     raise NotImplementedError()


-@torch_op(("aten::mul", "aten::mul.Tensor", "_operator::mul"), traceable=True)
+@torch_op(
+    ("aten::mul", "aten::mul.Tensor", "_operator::mul", "aten::multiply.Tensor"),
+    traceable=True,
+)
 def aten_mul(self: TReal, other: TReal) -> TReal:
     """mul.Tensor(Tensor self, Tensor other) -> Tensor"""

     return op.Mul(self, other)


-@torch_op(("aten::mul", "aten::mul.Tensor"))
+@torch_op(
+    ("aten::mul", "aten::mul.Tensor", "_operator::mul", "aten::multiply.Tensor"),
+    traceable=True,
+)
 def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
     """ONNX Mul doesn't support Boolean, so use And as an equivalent operator."""

@@ -5561,10 +5578,15 @@ def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
     return op.And(self, other)


-@torch_op(("aten::mul", "aten::mul.Tensor", "_operator::mul"), complex=True)
+@torch_op(
+    ("aten::mul", "aten::mul.Tensor", "_operator::mul", "aten::multiply.Tensor"),
+    traceable=True,
+    complex=True,
+)
 def aten_mul_complex(self: TReal, other: TReal) -> TReal:
     """mul.Tensor(Tensor self, Tensor other) -> Tensor"""

+    # TODO(justinchuby): Maybe use Split to simplify the logic
     self_real = op.Slice(self, [0], [1], axes=[-1])
     self_imag = op.Slice(self, [1], [2], axes=[-1])
     other_real = op.Slice(other, [0], [1], axes=[-1])
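The slices above split the trailing size-2 axis into real and imaginary parts; complex multiplication then follows (a + bi)(c + di) = (ac - bd) + (ad + bc)i. A sketch of the remaining, truncated steps:

    real = op.Sub(op.Mul(self_real, other_real), op.Mul(self_imag, other_imag))
    imag = op.Add(op.Mul(self_real, other_imag), op.Mul(self_imag, other_real))
    result = op.Concat(real, imag, axis=-1)  # repack the (..., 2) layout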
@@ -6580,7 +6602,7 @@ def aten_prelu_backward(
     raise NotImplementedError()


-@torch_op(("aten::prod.dim_int"), trace_only=True)
+@torch_op("aten::prod.dim_int", trace_only=True)
 def aten_prod(self: TReal, dim: int, keepdim: bool = False) -> TReal:
     """prod(Tensor self, *, ScalarType? dtype=None) -> Tensor"""

@@ -7966,7 +7988,15 @@ def aten_stft(
     return result


-@torch_op(("aten::sub.Tensor", "aten::subtract.Tensor", "_operator::sub"))
+@torch_op(
+    (
+        "aten::sub.Tensor",
+        "aten::sub.Scalar",
+        "aten::subtract.Tensor",
+        "aten::subtract.Scalar",
+        "_operator::sub",
+    )
+)
 def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     """sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
     alpha = op.CastLike(alpha, other)
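sub mirrors add with the sign flipped: self - alpha * other. Given the CastLike context line above, the truncated remainder is presumably along these lines:

    other = op.Mul(other, alpha)  # alpha was already cast to other's dtype
    result = op.Sub(self, other)  # self - alpha * other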
@@ -7976,7 +8006,13 @@ def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:


 @torch_op(
-    ("aten::sub.Tensor", "aten::subtract.Tensor", "_operator::sub"),
+    (
+        "aten::sub.Tensor",
+        "aten::sub.Scalar",
+        "aten::subtract.Tensor",
+        "aten::subtract.Scalar",
+        "_operator::sub",
+    ),
     trace_only=True,
     complex=True,
 )
@@ -8062,7 +8098,7 @@ def aten_swapdims(self: TensorType, dim0: int, dim1: int) -> TensorType:
     raise NotImplementedError()


-@torch_op("aten::sym_size")
+@torch_op("aten::sym_size.int")
 def aten_sym_size(self: TReal, dim: int = 0) -> TReal:
     """sym_size(Tensor self, int dim) -> Tensor"""
     # NOTE: onnxscript doesn't support attribute process,
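sym_size(self, dim) extracts one (possibly dynamic) dimension of the shape; the NOTE suggests `dim` cannot be handled as an ONNX attribute here, so it is consumed as data instead. A hypothetical sketch, not the verbatim body:

    def sym_size_sketch(self, dim):
        return op.Gather(op.Shape(self), dim)  # pick the dim-th shape entry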
@@ -8846,7 +8882,14 @@ def reshape_to_2d(tensor):
     return op.ConcatFromSequence(tensors_2d, axis=0)


-@torch_op(("aten::where", "aten::where.self"))
+@torch_op(
+    (
+        "aten::where.Scalar",
+        "aten::where.ScalarSelf",
+        "aten::where.ScalarOther",
+        "aten::where.self",
+    )
+)
 def aten_where(condition: BOOL, self: TTensor, other: TTensor) -> TTensor:
     """where.self(Tensor condition, Tensor self, Tensor other) -> Tensor"""

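With tensor self/other, where.self maps directly onto ONNX Where; the newly registered Scalar overloads can reuse the same body once the scalar side is promoted to a tensor. A sketch:

    def where_sketch(condition, self, other):
        return op.Where(condition, self, other)  # elementwise select by condition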