[torchlib] Fix registrations 3/n #1740

Merged · 15 commits · Jul 22, 2024
142 changes: 90 additions & 52 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -164,7 +164,7 @@
return op.Acosh(self)


@torch_op(("aten::add", "aten::add.Tensor", "_operator::add"))
@torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"))
def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
"""add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
# TODO(microsoft/onnxruntime#15977): Improve fp16 precision
@@ -173,7 +173,9 @@
return op.Add(self, other)


@torch_op(("aten::add", "aten::add.Tensor", "_operator::add"), trace_only=True, complex=True)
@torch_op(
("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True, complex=True
)
def aten_add_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
"""add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""

@@ -1140,7 +1142,7 @@
raise NotImplementedError()


@torch_op("aten::bernoulli")
@torch_op("aten::bernoulli", traceable=True)
def aten_bernoulli(self: TFloat) -> TFloat:
"""Proximal implementation of aten::bernoulli.default

@@ -1212,7 +1214,8 @@
"aten::bitwise_and.Scalar",
"aten::bitwise_and.Scalar_Tensor",
"_operator::and_",
)
),
traceable=True,
)
def aten_bitwise_and(self: TInt, other: TInt) -> TInt:
"""bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor"""
@@ -1226,7 +1229,8 @@
"aten::bitwise_left_shift.Tensor",
"aten::bitwise_left_shift.Tensor_Scalar",
"aten::bitwise_left_shift.Scalar_Tensor",
)
),
traceable=True,
)
def aten_bitwise_left_shift_int16(self: INT16, other: INT16) -> INT16:
"""bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
@@ -1244,7 +1248,8 @@
"aten::bitwise_left_shift.Tensor",
"aten::bitwise_left_shift.Tensor_Scalar",
"aten::bitwise_left_shift.Scalar_Tensor",
)
),
traceable=True,
)
def aten_bitwise_left_shift_int32(self: INT32, other: INT32) -> INT32:
"""bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
@@ -1262,7 +1267,8 @@
"aten::bitwise_left_shift.Tensor",
"aten::bitwise_left_shift.Tensor_Scalar",
"aten::bitwise_left_shift.Scalar_Tensor",
)
),
traceable=True,
)
def aten_bitwise_left_shift_int64(self: INT64, other: INT64) -> INT64:
"""bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
@@ -1280,7 +1286,8 @@
"aten::bitwise_left_shift.Tensor",
"aten::bitwise_left_shift.Tensor_Scalar",
"aten::bitwise_left_shift.Scalar_Tensor",
)
),
traceable=True,
)
def aten_bitwise_left_shift_int8(self: INT8, other: INT8) -> INT8:
"""bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor"""
@@ -1293,7 +1300,7 @@
return op.Cast(result, to=INT8.dtype)


@torch_op("aten::bitwise_not")
@torch_op("aten::bitwise_not", traceable=True)
def aten_bitwise_not(self: TInt) -> TInt:
"""bitwise_not(Tensor self) -> Tensor"""
# logical_not implements the BOOL variant
@@ -1307,7 +1314,8 @@
"aten::bitwise_or.Scalar",
"aten::bitwise_or.Scalar_Tensor",
"_operator::or_",
)
),
traceable=True,
)
def aten_bitwise_or(self: TInt, other: TInt) -> TInt:
"""bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor"""
@@ -1440,7 +1448,8 @@
"aten::bitwise_xor.Tensor",
"aten::bitwise_xor.Scalar",
"aten::bitwise_xor.Scalar_Tensor",
)
),
traceable=True,
)
def aten_bitwise_xor(self: TInt, other: TInt) -> TInt:
"""bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor"""
@@ -3480,15 +3489,14 @@
raise NotImplementedError()


@torch_op("aten::fill.Tensor")
def aten_fill(self: TTensor, value: TTensor) -> TTensor:
@torch_op(("aten::fill.Tensor", "aten::fill.Sclaar"))
def aten_fill(self: TTensor, value: TTensor2) -> TTensor:
"""fill.Tensor(Tensor self, Tensor value) -> Tensor"""

# after fill, the self Tensor should keep original type
# Cast the value before Expand so it can be constant folded
value = op.CastLike(value, self)
shape = op.Shape(self)
expanded = op.Expand(value, shape)
result = op.CastLike(expanded, self)
return result
return op.Expand(value, shape)
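
For intuition, a minimal NumPy sketch of what the reworked fill lowering computes — cast the value to the input's dtype first (so the Cast can be constant folded when value is a constant), then Expand it to the input's shape; x and value below are illustrative stand-ins, not part of the diff:

    import numpy as np

    x = np.zeros((2, 3), dtype=np.float16)     # stands in for `self`
    value = np.asarray(1.5).astype(x.dtype)    # op.CastLike(value, self)
    result = np.broadcast_to(value, x.shape)   # op.Expand(value, op.Shape(self))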


def aten_fix(self: TensorType) -> TensorType:
@@ -3497,17 +3505,20 @@
raise NotImplementedError()


@torch_op("aten::flip")
def aten_flip(self: TTensor, dims: INT64) -> TTensor:
@torch_op("aten::flip", trace_only=True)
def aten_flip(self: TTensor, dims: Sequence[int]) -> TTensor:
"""flip(Tensor self, int[] dims) -> Tensor"""

shape_dim = op.Shape(dims)
neg_1 = op.Constant(value_int=-1)
starts = op.Expand(neg_1, shape_dim) # something like [-1, -1, -1]
steps = op.Expand(neg_1, shape_dim) # something like [-1, -1, -1]
ends = op.Expand(_INT64_MIN, shape_dim) # something like [-xxx, -xxx, -xxx]
result = op.Slice(self, starts, ends, dims, steps)
return result
if not dims:
# Nothing to flip
return op.Identity(self)

rank = len(dims)
starts = op.Constant(value_ints=[-1] * rank) # something like [-1, -1, -1]
steps = starts # something like [-1, -1, -1]
ends = op.Constant(value_ints=[_INT64_MIN] * rank) # something like [-xxx, -xxx, -xxx]
dims = op.Constant(value_ints=dims)
return op.Slice(self, starts, ends, dims, steps)
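
For intuition, a minimal NumPy sketch of the reversal the new trace-only lowering expresses with Slice (starts=-1, steps=-1, ends=INT64_MIN on each listed axis); x and dims below are illustrative stand-ins, not part of the diff:

    import numpy as np

    x = np.arange(24).reshape(2, 3, 4)      # stands in for `self`
    dims = [0, 2]
    index = [slice(None)] * x.ndim
    for d in dims:
        index[d] = slice(None, None, -1)    # start=-1, end past the beginning, step=-1
    flipped = x[tuple(index)]               # same result as torch.flip(x, dims)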


def aten_fliplr(self: TensorType) -> TensorType:
@@ -3529,7 +3540,7 @@
return op.Floor(self)


@torch_op("math::floor")
@torch_op("math::floor", traceable=True)
def python_math_floor(self: TFloatOrBFloat16) -> TInt:
"""floor(Tensor self) -> Tensor"""
floor = op.Floor(self)
@@ -4834,9 +4845,11 @@
"aten::bitwise_or.Tensor",
"aten::bitwise_or.Scalar",
"aten::bitwise_or.Scalar_Tensor",
"aten::add",
"aten::add.Tensor",
)
"aten::add.Scalar",
"_operator::add",
),
traceable=True,
)
def aten_logical_or(self: BOOL, other: BOOL) -> BOOL:
"""logical_or(Tensor self, Tensor other) -> Tensor"""
@@ -5544,14 +5557,20 @@
raise NotImplementedError()


@torch_op(("aten::mul", "aten::mul.Tensor", "_operator::mul"), traceable=True)
@torch_op(
("aten::mul", "aten::mul.Tensor", "_operator::mul", "aten::multiply.Tensor"),
traceable=True,
)
def aten_mul(self: TReal, other: TReal) -> TReal:
"""mul.Tensor(Tensor self, Tensor other) -> Tensor"""

return op.Mul(self, other)


@torch_op(("aten::mul", "aten::mul.Tensor"))
@torch_op(
("aten::mul", "aten::mul.Tensor", "_operator::mul", "aten::multiply.Tensor"),
traceable=True,
)
def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
"""ONNX Mul doesn't support Boolean, so use And as an equivalent operator."""

@@ -5561,10 +5580,15 @@
return op.And(self, other)


@torch_op(("aten::mul", "aten::mul.Tensor", "_operator::mul"), complex=True)
@torch_op(
("aten::mul", "aten::mul.Tensor", "_operator::mul", "aten::multiply.Tensor"),
traceable=True,
complex=True,
)
def aten_mul_complex(self: TReal, other: TReal) -> TReal:
"""mul.Tensor(Tensor self, Tensor other) -> Tensor"""

# TODO(justinchuby): Maybe use Split to simplify the logic
self_real = op.Slice(self, [0], [1], axes=[-1])
self_imag = op.Slice(self, [1], [2], axes=[-1])
other_real = op.Slice(other, [0], [1], axes=[-1])
@@ -6580,7 +6604,7 @@
raise NotImplementedError()


@torch_op(("aten::prod.dim_int"), trace_only=True)
@torch_op("aten::prod.dim_int", trace_only=True)
def aten_prod(self: TReal, dim: int, keepdim: bool = False) -> TReal:
"""prod(Tensor self, *, ScalarType? dtype=None) -> Tensor"""

@@ -7966,7 +7990,15 @@
return result


@torch_op(("aten::sub.Tensor", "aten::subtract.Tensor", "_operator::sub"))
@torch_op(
(
"aten::sub.Tensor",
"aten::sub.Scalar",
"aten::subtract.Tensor",
"aten::subtract.Scalar",
"_operator::sub",
)
)
def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
"""sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
alpha = op.CastLike(alpha, other)
@@ -7976,7 +8008,13 @@


@torch_op(
("aten::sub.Tensor", "aten::subtract.Tensor", "_operator::sub"),
(
"aten::sub.Tensor",
"aten::sub.Scalar",
"aten::subtract.Tensor",
"aten::subtract.Scalar",
"_operator::sub",
),
trace_only=True,
complex=True,
)
@@ -8062,17 +8100,10 @@
raise NotImplementedError()


@torch_op("aten::sym_size")
def aten_sym_size(self: TReal, dim: int = 0) -> TReal:
"""sym_size(Tensor self, int dim) -> Tensor"""
# NOTE: onnxscript doesn't support attribute process,
# so op.Shape(self, start=dim, end=dim + 1) is not supported.
shape = op.Shape(self)
# Reshape helps dim from int to tensor, and
# input arguments support attribute processing.
start = op.Reshape(dim, op.Constant(value_ints=[1]))
end = op.Reshape(dim + 1, op.Constant(value_ints=[1]))
return op.Slice(shape, start, end)
@torch_op("aten::sym_size.int", trace_only=True)
def aten_sym_size(self: TensorType, dim: int = 0) -> INT64:
"""sym_size.int(Tensor self, int dim) -> SymInt"""
return op.Shape(self, end=dim + 1, start=dim)
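
A small illustrative check of the new lowering (not part of the diff): Shape with start/end attributes returns a one-element INT64 tensor holding the size of the requested dimension, which matches sym_size.int semantics:

    import numpy as np

    x = np.zeros((2, 3, 4))                     # stands in for `self`
    dim = 1
    shape = np.array(x.shape, dtype=np.int64)   # op.Shape(self)
    size = shape[dim:dim + 1]                    # op.Shape(self, start=dim, end=dim + 1) -> [3]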

Codecov / codecov/patch warning: added line onnxscript/function_libs/torch_lib/ops/core.py#L8106 was not covered by tests.


def aten_symeig(
@@ -8116,33 +8147,33 @@
raise NotImplementedError()


@torch_op("aten::tan")
@torch_op("aten::tan", traceable=True)
def aten_tan(self: TFloat) -> TFloat:
"""tan(Tensor self) -> Tensor"""

return op.Tan(self)


@torch_op("aten::tanh")
@torch_op("aten::tanh", traceable=True)
def aten_tanh(self: TFloat) -> TFloat:
"""tanh(Tensor self) -> Tensor"""

return op.Tanh(self)


@torch_op("aten::tensor.bool")
@torch_op("aten::tensor.bool", traceable=True)
def aten_tensor_bool(self: bool, dtype: int) -> TensorType:
tensor = op.Constant(value_int=self)
return op.Cast(tensor, to=dtype)


@torch_op("aten::tensor.float")
@torch_op("aten::tensor.float", traceable=True)
def aten_tensor_float(self: float, dtype: int) -> TensorType:
tensor = op.Constant(value_float=self)
return op.Cast(tensor, to=dtype)


@torch_op("aten::tensor.int")
@torch_op("aten::tensor.int", traceable=True)
def aten_tensor_int(self: int, dtype: int) -> TensorType:
tensor = op.Constant(value_int=self)
return op.Cast(tensor, to=dtype)
@@ -8846,7 +8877,14 @@
return op.ConcatFromSequence(tensors_2d, axis=0)


@torch_op(("aten::where", "aten::where.self"))
@torch_op(
(
"aten::where.Scalar",
"aten::where.ScalarSelf",
"aten::where.ScalarOther",
"aten::where.self",
)
)
def aten_where(condition: BOOL, self: TTensor, other: TTensor) -> TTensor:
"""where.self(Tensor condition, Tensor self, Tensor other) -> Tensor"""

13 changes: 4 additions & 9 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -271,14 +271,6 @@ def _empty_input_wrangler(
return args, kwargs


def _flip_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# Make the dims as tensor
kwargs["dims"] = np.array(kwargs["dims"], dtype=np.int64)
return args, kwargs


def _grid_sample_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
@@ -817,7 +809,10 @@ def _where_input_wrangler(
reason="fixme: The implementation is numerically unstable: https://github.com/microsoft/onnxscript/issues/1223"
),
TorchLibOpInfo("fill", core_ops.aten_fill),
TorchLibOpInfo("flip", core_ops.aten_flip, input_wrangler=_flip_input_wrangler),
TorchLibOpInfo("flip", core_ops.aten_flip).skip(
reason="fixme: size 0 inputs are not handled yet",
matcher=lambda sample: sample.input.numel() == 0,
),
TorchLibOpInfo("floor", core_ops.aten_floor),
TorchLibOpInfo("floor_divide", core_ops.aten_floor_divide).xfail(
dtypes=(torch.float16,),