From 577f51a2618c2f5503c2a40a8f16424fb3ade0f8 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 15 Nov 2023 17:35:18 -0800 Subject: [PATCH 1/3] Implement softplus | feat(torchlib) (#1157) --- onnxscript/function_libs/torch_lib/ops/nn.py | 7 +++++-- onnxscript/tests/function_libs/torch_lib/ops_test_data.py | 5 +++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index eeea85af6..d4d28059e 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2038,10 +2038,13 @@ def aten_soft_margin_loss_backward( raise NotImplementedError() -def aten_softplus(self: TensorType, beta: float = 1.0, threshold: float = 20.0) -> TensorType: +@torch_op("aten::softplus") +def aten_softplus(self: TFloat, beta: float = 1.0, threshold: float = 20.0) -> TFloat: """softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor""" - raise NotImplementedError() + self_scaled = self * beta + softplus = op.Softplus(self_scaled) / beta + return op.Where(self_scaled > threshold, self, softplus) def aten_softplus_backward( diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index a42a78090..c50d75238 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -1345,6 +1345,11 @@ def _where_input_wrangler( reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", test_class_name="TestOutputConsistencyFullGraph", ), + TorchLibOpInfo("nn.functional.softplus", nn_ops.aten_softplus).xfail( + dtypes=(torch.float16,), + reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16449", + test_class_name="TestOutputConsistencyEager", + ), TorchLibOpInfo( "split_with_sizes", core_ops.aten_split_with_sizes, From 804ed0192ae8d3a86380f1fd29ef7f182dc751ef Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 16 Nov 2023 17:31:25 -0800 Subject: [PATCH 2/3] Fix `aten_copy` dtype | fix(torchlib) (#1164) Cast the output of `aten_copy` to `self`'s type. Fixes https://github.com/microsoft/onnxscript/issues/1162 --- onnxscript/function_libs/torch_lib/ops/core.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 1b2ee6bea..8d61be90b 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2048,13 +2048,12 @@ def aten_convolution_overrideable( @torch_op("aten::copy") def aten_copy( self: TTensor, - src: TTensor, + src: TTensor2, non_blocking: bool = False, # pylint: disable=unused-argument ) -> TTensor: """copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor""" - self = op.Identity(src) - return self + return op.CastLike(src, self) @torch_op("aten::_to_copy", trace_only=True) From 10f9a1fd91f0358bcc2a61e2c3cf6639641be419 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 16 Nov 2023 21:45:54 -0800 Subject: [PATCH 3/3] Clean up Shape calls | chore(torchlib) (#1163) Update calls to `Shape` to use the `start` and `end` arguments to simplify the graph and avoid `Gather` nodes. --- onnxscript/function_libs/torch_lib/ops/core.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 8d61be90b..cf2fae3db 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4058,12 +4058,9 @@ def aten_index_put_bool( # change array([F,F,T,F,F]) to array([2]) index = op.ArgMax(index_int) # assume index only have 1 True # change array([2]) to array([2,2,2,2,2]) - self_dim_1 = op.Gather(op.Shape(self), 1) - index_dim_0 = op.Gather(op.Shape(index), 0) - neg_1 = op.Constant(value_ints=[-1]) - shape = op.Concat( - op.Reshape(self_dim_1, neg_1), op.Reshape(index_dim_0, neg_1), axis=0 - ) + self_dim_1 = op.Shape(self, start=1, end=2) + index_dim_0 = op.Shape(index, start=0, end=1) + shape = op.Concat(self_dim_1, index_dim_0, axis=0) new_ind = op.Expand(index, shape) new_ind_t = op.Transpose(new_ind) @@ -7512,7 +7509,7 @@ def _center_window_around_zeros_if_needed( window: TFloatOrBFloat16, n_fft: int ) -> TFloatOrBFloat16: # first dimension - n_win = op.Gather(op.Shape(window), 0) + n_win = op.Shape(window, start=0, end=1) # Center window around zeros if needed (required by ONNX's STFT) if n_win < n_fft: left = (n_fft - n_win) / 2