diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 4754588921..06fe120395 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -7699,12 +7699,23 @@ def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
     return result
 
 
+@torch_op("aten::sort", traceable=True)
 def aten_sort(
-    self: TensorType, dim: int = -1, descending: bool = False
-) -> tuple[TensorType, TensorType]:
+    self: TReal, dim: INT64 = -1, descending: bool = False
+) -> tuple[TReal, INT64]:
     """sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)"""
 
-    raise NotImplementedError()
+    self_is_scalar = IsScalar(self)
+    if self_is_scalar:
+        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+    shape = op.Shape(self)
+    dim_size = op.Gather(shape, dim, axis=0)
+    dim_size = op.Reshape(op.Cast(dim_size, to=INT64.dtype), op.Constant(value_ints=[1]))
+    values, indices = op.TopK(self, dim_size, axis=dim, largest=not descending, sorted=True)
+    if self_is_scalar:
+        values = op.Squeeze(values, op.Constant(value_ints=[0]))
+        indices = op.Squeeze(indices, op.Constant(value_ints=[0]))
+    return values, indices
 
 
 def aten_sparse_dim(self: TensorType) -> int:
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index 0a180bc483..8db28096ba 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -1440,6 +1440,7 @@ def _where_input_wrangler(
         reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16449",
         test_class_name="TestOutputConsistencyEager",
     ),
+    TorchLibOpInfo("sort", core_ops.aten_sort),
     TorchLibOpInfo(
         "split_with_sizes",
         core_ops.aten_split_with_sizes,
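
A quick sanity check of the TopK-based approach used above (a sketch only, assuming PyTorch is installed; the example tensor and assertions are illustrative and are not part of the PR's test suite):

import torch

# torch.sort along `dim` agrees with torch.topk when k equals the size of
# `dim`; this mirrors how aten_sort above drives ONNX TopK, where
# largest=not descending selects ascending vs. descending order.
x = torch.tensor([[3.0, 1.0, 2.0], [0.5, 2.5, 1.5]])
dim = -1
sort_values, sort_indices = torch.sort(x, dim=dim, descending=False)
topk_values, topk_indices = torch.topk(x, k=x.shape[dim], dim=dim, largest=False)
assert torch.equal(sort_values, topk_values)
assert torch.equal(sort_indices, topk_indices)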