Skip to content

Commit

Permalink
[torchlib] Add missing operators (set 2) (#1733)
Browse files Browse the repository at this point in the history
- [x] aten.special.expm1
- [x] aten.sort
  • Loading branch information
shubhambhokare1 authored Jul 18, 2024
1 parent 9ced95d commit d05d101
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 6 deletions.
16 changes: 12 additions & 4 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7699,12 +7699,20 @@ def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
return result


@torch_op("aten::sort", trace_only=True)
def aten_sort(
self: TensorType, dim: int = -1, descending: bool = False
) -> tuple[TensorType, TensorType]:
"""sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)"""
self: TReal, dim: int = -1, descending: bool = False, stable: bool = False
) -> tuple[TReal, INT64]:
"""sort(Tensor self, int dim=-1, bool descending=False, bool stable=False) -> (Tensor values, Tensor indices)"""

raise NotImplementedError()
self_is_scalar = IsScalar(self)
if self_is_scalar:
return op.Identity(self), op.Constant(value_int=0)
shape = op.Shape(self)
dim_size = op.Gather(shape, dim, axis=0)
dim_size = op.Reshape(dim_size, op.Constant(value_ints=[1]))
values, indices = op.TopK(self, dim_size, axis=dim, largest=descending, sorted=True)
return values, indices


def aten_sparse_dim(self: TensorType) -> int:
Expand Down
5 changes: 3 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,11 @@ def aten_special_expit(self: TensorType) -> TensorType:
raise NotImplementedError()


def aten_special_expm1(self: TensorType) -> TensorType:
@torch_op(("aten::expm1", "aten::special_expm"))
def aten_special_expm1(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
"""special_expm1(Tensor self) -> Tensor"""

raise NotImplementedError()
return op.Sub(op.Exp(self), 1)


def aten_special_gammainc(self: TensorType, other: TensorType) -> TensorType:
Expand Down
7 changes: 7 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,9 @@ def _where_input_wrangler(
TorchLibOpInfo(
"erfc", special_ops.aten_special_erfc, tolerance={torch.float16: (1e-2, 2e-4)}
),
TorchLibOpInfo(
"expm1", special_ops.aten_special_expm1, tolerance={torch.float16: (1e-2, 2e-4)}
),
TorchLibOpInfo("special.erfcx", special_ops.aten_special_erfcx).xfail(
reason="fixme: The implementation is numerically unstable: https://github.com/microsoft/onnxscript/issues/1223"
),
Expand Down Expand Up @@ -1437,6 +1440,10 @@ def _where_input_wrangler(
reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16449",
test_class_name="TestOutputConsistencyEager",
),
TorchLibOpInfo("sort", core_ops.aten_sort).xfail(
dtypes=(torch.float16,),
reason="fixme: Tensor-likes are not close. Tests pass for float32.",
),
TorchLibOpInfo(
"split_with_sizes",
core_ops.aten_split_with_sizes,
Expand Down

0 comments on commit d05d101

Please sign in to comment.