diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 475458892..e4a29c030 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7699,12 +7699,20 @@ def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16: return result +@torch_op("aten::sort", trace_only=True) def aten_sort( - self: TensorType, dim: int = -1, descending: bool = False -) -> tuple[TensorType, TensorType]: - """sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)""" + self: TReal, dim: int = -1, descending: bool = False, stable: bool = False +) -> tuple[TReal, INT64]: + """sort(Tensor self, int dim=-1, bool descending=False, bool stable=False) -> (Tensor values, Tensor indices)""" - raise NotImplementedError() + self_is_scalar = IsScalar(self) + if self_is_scalar: + return op.Identity(self), op.Constant(value_int=0) + shape = op.Shape(self) + dim_size = op.Gather(shape, dim, axis=0) + dim_size = op.Reshape(dim_size, op.Constant(value_ints=[1])) + values, indices = op.TopK(self, dim_size, axis=dim, largest=descending, sorted=True) + return values, indices def aten_sparse_dim(self: TensorType) -> int: diff --git a/onnxscript/function_libs/torch_lib/ops/special.py b/onnxscript/function_libs/torch_lib/ops/special.py index bf4746261..980cf881e 100644 --- a/onnxscript/function_libs/torch_lib/ops/special.py +++ b/onnxscript/function_libs/torch_lib/ops/special.py @@ -130,10 +130,11 @@ def aten_special_expit(self: TensorType) -> TensorType: raise NotImplementedError() -def aten_special_expm1(self: TensorType) -> TensorType: +@torch_op(("aten::expm1", "aten::special_expm")) +def aten_special_expm1(self: TFloatOrBFloat16) -> TFloatOrBFloat16: """special_expm1(Tensor self) -> Tensor""" - raise NotImplementedError() + return op.Sub(op.Exp(self), 1) def aten_special_gammainc(self: TensorType, other: TensorType) -> TensorType: diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 773c19f1d..bad3e8eb6 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -810,6 +810,9 @@ def _where_input_wrangler( TorchLibOpInfo( "erfc", special_ops.aten_special_erfc, tolerance={torch.float16: (1e-2, 2e-4)} ), + TorchLibOpInfo( + "expm1", special_ops.aten_special_expm1, tolerance={torch.float16: (1e-2, 2e-4)} + ), TorchLibOpInfo("special.erfcx", special_ops.aten_special_erfcx).xfail( reason="fixme: The implementation is numerically unstable: https://github.com/microsoft/onnxscript/issues/1223" ), @@ -1437,6 +1440,10 @@ def _where_input_wrangler( reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16449", test_class_name="TestOutputConsistencyEager", ), + TorchLibOpInfo("sort", core_ops.aten_sort).xfail( + dtypes=(torch.float16,), + reason="fixme: Tensor-likes are not close. Tests pass for float32.", + ), TorchLibOpInfo( "split_with_sizes", core_ops.aten_split_with_sizes,