[torchlib] Fix linspace and full #1742

Merged · 10 commits · Jul 23, 2024
74 changes: 40 additions & 34 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -550,9 +550,6 @@
) -> TensorType:
"""arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

# NOTE: trace_only because both if branches need to be the same type, but we have
# a cast in the if branch.

if dtype == -1:
zero = op.CastLike(0.0, end)
one = op.CastLike(1.0, end)
@@ -1229,6 +1226,7 @@
"aten::bitwise_left_shift.Tensor",
"aten::bitwise_left_shift.Tensor_Scalar",
"aten::bitwise_left_shift.Scalar_Tensor",
"_operator::__lshift__",
),
traceable=True,
)
@@ -1248,6 +1246,7 @@
"aten::bitwise_left_shift.Tensor",
"aten::bitwise_left_shift.Tensor_Scalar",
"aten::bitwise_left_shift.Scalar_Tensor",
"_operator::__lshift__",
),
traceable=True,
)
Expand All @@ -1267,6 +1266,7 @@
"aten::bitwise_left_shift.Tensor",
"aten::bitwise_left_shift.Tensor_Scalar",
"aten::bitwise_left_shift.Scalar_Tensor",
"_operator::__lshift__",
),
traceable=True,
)
Expand All @@ -1286,6 +1286,7 @@
"aten::bitwise_left_shift.Tensor",
"aten::bitwise_left_shift.Tensor_Scalar",
"aten::bitwise_left_shift.Scalar_Tensor",
"_operator::__lshift__",
),
traceable=True,
)
@@ -1329,6 +1330,7 @@
"aten::bitwise_right_shift.Tensor",
"aten::bitwise_right_shift.Tensor_Scalar",
"aten::bitwise_right_shift.Scalar_Tensor",
"_operator::__rshift__",
)
)
def aten_bitwise_right_shift_int16(self: INT16, other: INT16) -> INT16:
@@ -1358,6 +1360,7 @@
"aten::bitwise_right_shift.Tensor",
"aten::bitwise_right_shift.Tensor_Scalar",
"aten::bitwise_right_shift.Scalar_Tensor",
"_operator::__rshift__",
)
)
def aten_bitwise_right_shift_int32(self: INT32, other: INT32) -> INT32:
@@ -1387,6 +1390,7 @@
"aten::bitwise_right_shift.Tensor",
"aten::bitwise_right_shift.Tensor_Scalar",
"aten::bitwise_right_shift.Scalar_Tensor",
"_operator::__rshift__",
)
)
def aten_bitwise_right_shift_int64(self: INT64, other: INT64) -> INT64:
@@ -1419,6 +1423,7 @@
"aten::bitwise_right_shift.Tensor",
"aten::bitwise_right_shift.Tensor_Scalar",
"aten::bitwise_right_shift.Scalar_Tensor",
"_operator::__rshift__",
)
)
def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8:
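(body elided above)

A note on the new `_operator::__lshift__` / `_operator::__rshift__` targets: these are the names the exporter can emit when plain Python shift operators are applied, so registering them routes those calls to the existing typed implementations. The function bodies are elided in this diff; as a hedged sketch, a signed shift in ONNX terms typically round-trips through an unsigned type, since ONNX BitShift is only defined for unsigned integers (this assumes the module's `op`, `UINT32`, and `INT32` imports, and is illustrative rather than the elided body):

def bitwise_left_shift_int32_sketch(self, other):
    # ONNX BitShift only accepts unsigned integer tensors, so a signed
    # INT32 shift casts to UINT32, shifts, and casts back.
    self_u = op.Cast(self, to=UINT32.dtype)
    other_u = op.Cast(other, to=UINT32.dtype)
    shifted = op.BitShift(self_u, other_u, direction="LEFT")
    return op.Cast(shifted, to=INT32.dtype)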
@@ -3606,30 +3611,35 @@

@torch_op("aten::full", trace_only=True)
def aten_full(
size: INT64,
fill_value: FLOAT,
size: Union[INT64, INT32],
fill_value: TensorType,
dtype: int = FLOAT.dtype,
layout: str = "",
device: str = "",
pin_memory: bool = False,
):
) -> TensorType:
"""full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

size = op.Cast(size, to=INT64.dtype)
if dtype != -1:
fill_value = op.Cast(fill_value, to=dtype)
if isinstance(size, list) and size == []:
# TODO(justinchuby): Handle empty list better than using isinstance
# size can be empty, meaning a scalar
return fill_value

Codecov / codecov/patch — onnxscript/function_libs/torch_lib/ops/core.py#L3628: added line was not covered by tests.

size = op.Cast(size, to=INT64.dtype)
return op.Expand(fill_value, size)

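The reworked `aten_full` now short-circuits before the cast to INT64: an empty `size` list denotes a scalar output, in which case `fill_value` is returned directly and the `Expand` is skipped. A rough eager-mode equivalent of the traced logic, using numpy purely as an illustration (not what the op emits):

import numpy as np

def full_sketch(size, fill_value):
    # Empty size list means a scalar result: return the fill value as-is.
    if isinstance(size, list) and size == []:
        return np.asarray(fill_value)
    # Otherwise broadcast the fill value to the requested shape, which is
    # what op.Expand(fill_value, size) computes in the traced graph.
    return np.broadcast_to(np.asarray(fill_value), size).copy()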

@torch_op("aten::full_like", trace_only=True)
def aten_full_like(
self: TTensor,
fill_value: TTensor,
self: TensorType,
fill_value: TensorType,
dtype: int = -1,
layout: str = "",
device: str = "",
pin_memory: bool = False,
) -> TTensor:
) -> TensorType:
"""full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""

if dtype == -1:
@@ -4715,11 +4725,17 @@

@torch_op("aten::linspace", trace_only=True)
def aten_linspace(
start: TFloat, end: TFloat, steps: int, dtype: int = FLOAT.dtype
start: TFloat,
end: TFloat,
steps: int,
dtype: int = FLOAT.dtype,
layout: str = "",
device: str = "",
pin_memory: bool = False,
) -> TensorType:
"""linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""

if dtype == -1:
if dtype == -1 or dtype is None:
dtype = FLOAT.dtype

# Reference: https://github.com/pytorch/pytorch/blob/b35ca2cb941b5ba90858322810ca85c31e4541fd/torch/_refs/__init__.py#L4896
Expand All @@ -4743,14 +4759,14 @@
)

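The visible change to `aten_linspace` adds the `layout`/`device`/`pin_memory` kwargs and accepts `dtype is None` alongside the `-1` sentinel; the elided body follows the linked PyTorch reference. As a sketch of the arithmetic only (the real implementation builds this with ONNX ops):

def linspace_sketch(start, end, steps):
    # A single step degenerates to [start]; otherwise points are spaced
    # evenly with step = (end - start) / (steps - 1), inclusive of end.
    if steps == 1:
        return [float(start)]
    step = (end - start) / (steps - 1)
    return [start + i * step for i in range(steps)]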

@torch_op("aten::log")
@torch_op("aten::log", traceable=True)
def aten_log(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
"""log(Tensor self) -> Tensor"""

return op.Log(self)


@torch_op("aten::log10")
@torch_op("aten::log10", traceable=True)
def aten_log10(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
"""log10(Tensor self) -> Tensor"""

Expand All @@ -4764,21 +4780,21 @@
return op.Log(op.Add(self, 1.0))


@torch_op("aten::log2")
@torch_op("aten::log2", traceable=True)
def aten_log2(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
"""log2(Tensor self) -> Tensor"""

return op.Div(op.Log(self), op.CastLike(op.Log(2.0), self))


@torch_op("aten::logaddexp")
@torch_op("aten::logaddexp", traceable=True)
def aten_logaddexp(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
"""logaddexp(Tensor self, Tensor other) -> Tensor"""

return op.Log(op.Add(op.Exp(self), op.Exp(other)))


@torch_op("aten::logaddexp2")
@torch_op("aten::logaddexp2", traceable=True)
def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
"""logaddexp2(Tensor self, Tensor other) -> Tensor"""
two = op.CastLike(2.0, self)
@@ -4811,7 +4827,7 @@
return result


@torch_op("aten::logdet")
@torch_op("aten::logdet", traceable=True)
def aten_logdet(self: TFloat) -> TFloat:
"""logdet(Tensor self) -> Tensor"""

Expand All @@ -4824,15 +4840,16 @@
"aten::bitwise_and.Tensor",
"aten::bitwise_and.Scalar",
"aten::bitwise_and.Scalar_Tensor",
)
),
traceable=True,
)
def aten_logical_and(self: BOOL, other: BOOL) -> BOOL:
"""logical_and(Tensor self, Tensor other) -> Tensor"""

return op.And(self, other)


@torch_op(("aten::logical_not", "aten::bitwise_not"))
@torch_op(("aten::logical_not", "aten::bitwise_not"), traceable=True)
def aten_logical_not(self: BOOL) -> BOOL:
"""logical_not(Tensor self) -> Tensor"""

@@ -4863,7 +4880,8 @@
"aten::bitwise_xor.Tensor",
"aten::bitwise_xor.Scalar",
"aten::bitwise_xor.Scalar_Tensor",
)
),
traceable=True,
)
def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL:
"""logical_xor(Tensor self, Tensor other) -> Tensor"""
@@ -4912,12 +4930,6 @@
return result


def aten_lshift(self: TensorType, other: TensorType) -> TensorType:
"""__lshift__.Tensor(Tensor self, Tensor other) -> Tensor"""

raise NotImplementedError()


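Removing the `aten_lshift` stub here (and `aten_rshift` below) is consistent with folding `_operator::__lshift__`/`__rshift__` into the typed shift registrations earlier in this diff. A hypothetical module showing where that `_operator` target can surface during export (the exact trace depends on the exporter version):

import torch

class ShiftDim(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shifting a symbolic size is plain Python integer arithmetic, so
        # dynamo may record it as _operator.__lshift__ rather than an
        # aten::bitwise_left_shift overload.
        doubled = x.shape[0] << 1
        return x.new_zeros((doubled,))
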
def aten_lstm_cell(
input: TensorType,
hx: Sequence[TensorType],
@@ -6226,7 +6238,7 @@
def aten_new_full(
self: TTensor,
size: INT64,
fill_value: TTensor,
fill_value: TensorType,
dtype: int = -1,
layout: str = "",
device: str = "",
@@ -7308,12 +7320,6 @@
raise NotImplementedError()


def aten_rshift(self: TensorType, other: TensorType) -> TensorType:
"""__rshift__.Tensor(Tensor self, Tensor other) -> Tensor"""

raise NotImplementedError()


@torch_op("aten::rsqrt")
def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
"""rsqrt(Tensor self) -> Tensor"""