diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index ddd836c4a..8f99233d3 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -6393,25 +6393,18 @@ def aten_ones_like( device: str = "", pin_memory: bool = False, ) -> TTensor: - """ones_like. + """ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor Note: dtype is an onnx enum. Users should convert torch dtype to onnx dtype before calling this function. """ - # ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor - - # NOTE: trace_only because both if branches need to be the same type, but we have - # a cast in the if branch. + if dtype is None: + dtype = -1 if dtype == -1: one = op.CastLike(1, self) else: one = op.Cast(1, to=dtype) - return _aten_ones_like_onnx(self, one) - - -@torch_op("aten::ones_like", private=True) -def _aten_ones_like_onnx(self: TTensor, one) -> TTensor: shape = op.Shape(self) return op.Expand(one, shape) @@ -8861,6 +8854,8 @@ def aten_zeros_like(self: TTensor, dtype: int = -1) -> TTensor: # NOTE: trace_only because both if branches need to be the same type, but we have # a cast in the if branch. + if dtype is None: + dtype = -1 if dtype == -1: zero = op.CastLike(0, self) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index b4f42096e..a26bcbe7c 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2046,10 +2046,11 @@ def aten_sigmoid_backward(grad_output: TensorType, output: TensorType) -> Tensor raise NotImplementedError() -def aten_silu(self: TensorType) -> TensorType: +@torch_op("aten::silu", traceable=True) +def aten_silu(self: TFloat) -> TFloat: """silu(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Mul(self, op.Sigmoid(self)) def aten_silu_backward(grad_output: TensorType, self: TensorType) -> TensorType: diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 999211f83..b7038ada7 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1390,6 +1390,7 @@ def _where_input_wrangler( TorchLibOpInfo("select_scatter", core_ops.aten_select_scatter), TorchLibOpInfo("sigmoid", core_ops.aten_sigmoid), TorchLibOpInfo("sign", core_ops.aten_sign), + TorchLibOpInfo("nn.functional.silu", nn_ops.aten_silu), TorchLibOpInfo("sin", core_ops.aten_sin), TorchLibOpInfo( "sinc", special_ops.aten_special_sinc, tolerance={torch.float16: (1e-2, 6e-4)}