Skip to content

Commit

Permalink
add sort impl
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhambhokare1 committed Jul 17, 2024
1 parent f12d737 commit 7af7d90
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
17 changes: 14 additions & 3 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7699,12 +7699,23 @@ def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
return result


@torch_op("aten::sort", traceable=True)
def aten_sort(
self: TensorType, dim: int = -1, descending: bool = False
) -> tuple[TensorType, TensorType]:
self: TReal, dim: INT64 = -1, descending: bool = False
) -> tuple[TReal, INT64]:
"""sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)"""

raise NotImplementedError()
self_is_scalar = IsScalar(self)

Check warning on line 7708 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L7708

Added line #L7708 was not covered by tests
if self_is_scalar:
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
shape = op.Shape(self)
dim_size = op.Gather(shape, dim, axis=0)
dim_size = op.Reshape(op.Cast(dim_size, to=INT64.dtype), op.Constant(value_ints=[1]))
values, indices = op.TopK(self, dim_size, axis=dim, largest=not descending, sorted=sorted)

Check warning on line 7714 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L7710-L7714

Added lines #L7710 - L7714 were not covered by tests
if self_is_scalar:
values = op.Squeeze(values, op.Constant(value_ints=[0]))
indices = op.Squeeze(indices, op.Constant(value_ints=[0]))
return values, indices

Check warning on line 7718 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L7716-L7718

Added lines #L7716 - L7718 were not covered by tests


def aten_sparse_dim(self: TensorType) -> int:
Expand Down
1 change: 1 addition & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1440,6 +1440,7 @@ def _where_input_wrangler(
reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16449",
test_class_name="TestOutputConsistencyEager",
),
TorchLibOpInfo("sort", core_ops.aten_sort),
TorchLibOpInfo(
"split_with_sizes",
core_ops.aten_split_with_sizes,
Expand Down

0 comments on commit 7af7d90

Please sign in to comment.