[torchlib] Implement silu and fix ones_like (#1718)
Needed to export phi-3
justinchuby authored Jul 3, 2024
1 parent e824285 commit ee29e71
Showing 3 changed files with 9 additions and 12 deletions.
onnxscript/function_libs/torch_lib/ops/core.py (5 additions, 10 deletions)
@@ -6393,25 +6393,18 @@ def aten_ones_like(
     device: str = "",
     pin_memory: bool = False,
 ) -> TTensor:
-    """ones_like.
+    """ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
 
     Note: dtype is an onnx enum. Users should convert torch dtype to onnx dtype
     before calling this function.
     """
-    # ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
-
     # NOTE: trace_only because both if branches need to be the same type, but we have
     # a cast in the if branch.
+    if dtype is None:
+        dtype = -1
 
     if dtype == -1:
         one = op.CastLike(1, self)
     else:
         one = op.Cast(1, to=dtype)
-    return _aten_ones_like_onnx(self, one)
-
-
-@torch_op("aten::ones_like", private=True)
-def _aten_ones_like_onnx(self: TTensor, one) -> TTensor:
     shape = op.Shape(self)
     return op.Expand(one, shape)

@@ -8861,6 +8854,8 @@ def aten_zeros_like(self: TTensor, dtype: int = -1) -> TTensor:
 
     # NOTE: trace_only because both if branches need to be the same type, but we have
     # a cast in the if branch.
+    if dtype is None:
+        dtype = -1
 
     if dtype == -1:
         zero = op.CastLike(0, self)
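Both dtype fixes above follow the same pattern: the exporter can pass dtype=None even though the signature defaults to the -1 sentinel, so a None check is added before the branch (otherwise None falls through to op.Cast(1, to=None) and fails). Below is a minimal NumPy sketch of the fixed dispatch, not part of the commit; onnx_dtype_to_np is a hypothetical stand-in for the ONNX TensorProto dtype enum.

import numpy as np

# Hypothetical mapping from ONNX TensorProto dtype enums to NumPy dtypes.
onnx_dtype_to_np = {1: np.float32, 7: np.int64, 11: np.float64}

def ones_like_reference(self_array, dtype=-1):
    # The fix: normalize None to the -1 sentinel before branching.
    if dtype is None:
        dtype = -1

    if dtype == -1:
        one = np.array(1, dtype=self_array.dtype)         # mirrors op.CastLike(1, self)
    else:
        one = np.array(1, dtype=onnx_dtype_to_np[dtype])  # mirrors op.Cast(1, to=dtype)
    return np.broadcast_to(one, self_array.shape).copy()  # mirrors Expand(one, Shape(self))

x = np.zeros((2, 3), dtype=np.float16)
assert ones_like_reference(x).dtype == np.float16             # keeps input dtype
assert ones_like_reference(x, dtype=7).dtype == np.int64      # casts to requested dtype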
onnxscript/function_libs/torch_lib/ops/nn.py (3 additions, 2 deletions)
@@ -2046,10 +2046,11 @@ def aten_sigmoid_backward(grad_output: TensorType, output: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-def aten_silu(self: TensorType) -> TensorType:
+@torch_op("aten::silu", traceable=True)
+def aten_silu(self: TFloat) -> TFloat:
     """silu(Tensor self) -> Tensor"""
 
-    raise NotImplementedError()
+    return op.Mul(self, op.Sigmoid(self))
 
 
 def aten_silu_backward(grad_output: TensorType, self: TensorType) -> TensorType:
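The implementation uses the standard decomposition SiLU(x) = x * sigmoid(x), written with the ONNX Mul and Sigmoid ops. A quick sanity check of that identity against PyTorch's own silu, not taken from the PR:

import torch

x = torch.randn(4, 8, dtype=torch.float32)
decomposed = x * torch.sigmoid(x)  # what op.Mul(self, op.Sigmoid(self)) computes
torch.testing.assert_close(decomposed, torch.nn.functional.silu(x))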
tests/function_libs/torch_lib/ops_test_data.py (1 addition, 0 deletions)
@@ -1390,6 +1390,7 @@ def _where_input_wrangler(
     TorchLibOpInfo("select_scatter", core_ops.aten_select_scatter),
     TorchLibOpInfo("sigmoid", core_ops.aten_sigmoid),
     TorchLibOpInfo("sign", core_ops.aten_sign),
+    TorchLibOpInfo("nn.functional.silu", nn_ops.aten_silu),
     TorchLibOpInfo("sin", core_ops.aten_sin),
     TorchLibOpInfo(
         "sinc", special_ops.aten_special_sinc, tolerance={torch.float16: (1e-2, 6e-4)}
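Registering a TorchLibOpInfo entry is what picks up the new op in the generated test suite, which compares the torchlib function against PyTorch over the OpInfo sample inputs. Assuming the usual layout, with the generated tests in ops_test.py next to ops_test_data.py, the new case could be run on its own like this:

import pytest

# Assumed location of the generated op tests; filter to the new silu entry.
pytest.main(["tests/function_libs/torch_lib/ops_test.py", "-k", "silu"])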
