From db208775cf96b28115cf360075f0818ded6659cc Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Wed, 17 Jul 2024 17:40:20 +0000 Subject: [PATCH 1/8] Add expm1 operator --- onnxscript/function_libs/torch_lib/ops/special.py | 5 +++-- tests/function_libs/torch_lib/ops_test_data.py | 3 +++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/special.py b/onnxscript/function_libs/torch_lib/ops/special.py index bf4746261..980cf881e 100644 --- a/onnxscript/function_libs/torch_lib/ops/special.py +++ b/onnxscript/function_libs/torch_lib/ops/special.py @@ -130,10 +130,11 @@ def aten_special_expit(self: TensorType) -> TensorType: raise NotImplementedError() -def aten_special_expm1(self: TensorType) -> TensorType: +@torch_op(("aten::expm1", "aten::special_expm")) +def aten_special_expm1(self: TFloatOrBFloat16) -> TFloatOrBFloat16: """special_expm1(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Sub(op.Exp(self), 1) def aten_special_gammainc(self: TensorType, other: TensorType) -> TensorType: diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 773c19f1d..0a180bc48 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -810,6 +810,9 @@ def _where_input_wrangler( TorchLibOpInfo( "erfc", special_ops.aten_special_erfc, tolerance={torch.float16: (1e-2, 2e-4)} ), + TorchLibOpInfo( + "expm1", special_ops.aten_special_expm1, tolerance={torch.float16: (1e-2, 2e-4)} + ), TorchLibOpInfo("special.erfcx", special_ops.aten_special_erfcx).xfail( reason="fixme: The implementation is numerically unstable: https://github.com/microsoft/onnxscript/issues/1223" ), From ee57318f7450b269f05fa5f6a17f7630faf14693 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Wed, 17 Jul 2024 18:14:00 +0000 Subject: [PATCH 2/8] add sort impl --- onnxscript/function_libs/torch_lib/ops/core.py | 17 ++++++++++++++--- tests/function_libs/torch_lib/ops_test_data.py | 1 + 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 475458892..06fe12039 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7699,12 +7699,23 @@ def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16: return result +@torch_op("aten::sort", traceable=True) def aten_sort( - self: TensorType, dim: int = -1, descending: bool = False -) -> tuple[TensorType, TensorType]: + self: TReal, dim: INT64 = -1, descending: bool = False +) -> tuple[TReal, INT64]: """sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)""" - raise NotImplementedError() + self_is_scalar = IsScalar(self) + if self_is_scalar: + self = op.Unsqueeze(self, op.Constant(value_ints=[0])) + shape = op.Shape(self) + dim_size = op.Gather(shape, dim, axis=0) + dim_size = op.Reshape(op.Cast(dim_size, to=INT64.dtype), op.Constant(value_ints=[1])) + values, indices = op.TopK(self, dim_size, axis=dim, largest=not descending, sorted=sorted) + if self_is_scalar: + values = op.Squeeze(values, op.Constant(value_ints=[0])) + indices = op.Squeeze(indices, op.Constant(value_ints=[0])) + return values, indices def aten_sparse_dim(self: TensorType) -> int: diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 0a180bc48..8db28096b 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1440,6 +1440,7 @@ def _where_input_wrangler( reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16449", test_class_name="TestOutputConsistencyEager", ), + TorchLibOpInfo("sort", core_ops.aten_sort), TorchLibOpInfo( "split_with_sizes", core_ops.aten_split_with_sizes, From 735d50909ebe3c948fe24adf163ee504c3dfe5fc Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Wed, 17 Jul 2024 18:44:45 +0000 Subject: [PATCH 3/8] small fix --- onnxscript/function_libs/torch_lib/ops/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 06fe12039..b9590763d 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7701,9 +7701,9 @@ def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16: @torch_op("aten::sort", traceable=True) def aten_sort( - self: TReal, dim: INT64 = -1, descending: bool = False + self: TReal, dim: int = -1, descending: bool = False, stable: bool = False ) -> tuple[TReal, INT64]: - """sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)""" + """sort(Tensor self, int dim=-1, bool descending=False, bool stable=False) -> (Tensor values, Tensor indices)""" self_is_scalar = IsScalar(self) if self_is_scalar: @@ -7711,7 +7711,7 @@ def aten_sort( shape = op.Shape(self) dim_size = op.Gather(shape, dim, axis=0) dim_size = op.Reshape(op.Cast(dim_size, to=INT64.dtype), op.Constant(value_ints=[1])) - values, indices = op.TopK(self, dim_size, axis=dim, largest=not descending, sorted=sorted) + values, indices = op.TopK(self, dim_size, axis=dim, largest=descending, sorted=True) if self_is_scalar: values = op.Squeeze(values, op.Constant(value_ints=[0])) indices = op.Squeeze(indices, op.Constant(value_ints=[0])) From 583de7f8808a19d8b68ba968e8f491672e382515 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Wed, 17 Jul 2024 19:38:55 +0000 Subject: [PATCH 4/8] Fix test fails --- tests/function_libs/torch_lib/ops_test_data.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 8db28096b..a2a114405 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1440,7 +1440,13 @@ def _where_input_wrangler( reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16449", test_class_name="TestOutputConsistencyEager", ), - TorchLibOpInfo("sort", core_ops.aten_sort), + TorchLibOpInfo( + "sort", + core_ops.aten_sort + ).xfail( + dtypes=(torch.float16,), + reason="fixme: Tensor-likes are not close. Tests pass for float32." + ), TorchLibOpInfo( "split_with_sizes", core_ops.aten_split_with_sizes, From f7bb621dac7914370fcd86c3e80326aab77dfb5d Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Wed, 17 Jul 2024 21:17:14 +0000 Subject: [PATCH 5/8] minor edits --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index b9590763d..a761924bd 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7710,7 +7710,7 @@ def aten_sort( self = op.Unsqueeze(self, op.Constant(value_ints=[0])) shape = op.Shape(self) dim_size = op.Gather(shape, dim, axis=0) - dim_size = op.Reshape(op.Cast(dim_size, to=INT64.dtype), op.Constant(value_ints=[1])) + dim_size = op.Reshape(dim_size, op.Constant(value_ints=[1])) values, indices = op.TopK(self, dim_size, axis=dim, largest=descending, sorted=True) if self_is_scalar: values = op.Squeeze(values, op.Constant(value_ints=[0])) From 224a35fa12e5ac7cf0b36ae0f235c5edbc9c1195 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Wed, 17 Jul 2024 23:26:47 +0000 Subject: [PATCH 6/8] make trace_only --- onnxscript/function_libs/torch_lib/ops/core.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index a761924bd..edc8daad6 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7699,7 +7699,7 @@ def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16: return result -@torch_op("aten::sort", traceable=True) +@torch_op("aten::sort", trace_only=True) def aten_sort( self: TReal, dim: int = -1, descending: bool = False, stable: bool = False ) -> tuple[TReal, INT64]: @@ -7707,14 +7707,11 @@ def aten_sort( self_is_scalar = IsScalar(self) if self_is_scalar: - self = op.Unsqueeze(self, op.Constant(value_ints=[0])) + return op.Identity(self), op.Squeeze(op.Constant(value_ints=[0])) shape = op.Shape(self) dim_size = op.Gather(shape, dim, axis=0) dim_size = op.Reshape(dim_size, op.Constant(value_ints=[1])) values, indices = op.TopK(self, dim_size, axis=dim, largest=descending, sorted=True) - if self_is_scalar: - values = op.Squeeze(values, op.Constant(value_ints=[0])) - indices = op.Squeeze(indices, op.Constant(value_ints=[0])) return values, indices From 6e7e424512b5fb5b8c2c482669f07a5852cd2419 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Wed, 17 Jul 2024 23:54:00 +0000 Subject: [PATCH 7/8] lint --- tests/function_libs/torch_lib/ops_test_data.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index a2a114405..bad3e8eb6 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1440,12 +1440,9 @@ def _where_input_wrangler( reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16449", test_class_name="TestOutputConsistencyEager", ), - TorchLibOpInfo( - "sort", - core_ops.aten_sort - ).xfail( + TorchLibOpInfo("sort", core_ops.aten_sort).xfail( dtypes=(torch.float16,), - reason="fixme: Tensor-likes are not close. Tests pass for float32." + reason="fixme: Tensor-likes are not close. Tests pass for float32.", ), TorchLibOpInfo( "split_with_sizes", From f8baeea70a04035b2b880b6435c35d69e2832bad Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Thu, 18 Jul 2024 16:16:09 +0000 Subject: [PATCH 8/8] minor fix --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index edc8daad6..e4a29c030 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7707,7 +7707,7 @@ def aten_sort( self_is_scalar = IsScalar(self) if self_is_scalar: - return op.Identity(self), op.Squeeze(op.Constant(value_ints=[0])) + return op.Identity(self), op.Constant(value_int=0) shape = op.Shape(self) dim_size = op.Gather(shape, dim, axis=0) dim_size = op.Reshape(dim_size, op.Constant(value_ints=[1]))