[torchlib] Implement missing operators (set1) (#1706)
Implement missing operators uncovered by torch.onnx tests, as tracked in #1644.

- [x] Implement `<OpOverload(op='aten.fmod', overload='Scalar')>`
- [x] Implement `<OpOverload(op='aten.fmod', overload='Tensor')>`
- [x] Implement `<OpOverload(op='aten.glu', overload='default')>`
- [x] Implement `<OpOverload(op='aten.le', overload='Scalar')>`
- [x] Implement `<OpOverload(op='aten.lerp', overload='Scalar')>` (see the lerp sketch after this list)
- [x] Implement `<OpOverload(op='aten.linalg_cross', overload='default')>`
- [x] Implement `<OpOverload(op='aten.mv', overload='default')>`
- [x] Implement `<OpOverload(op='aten.pow', overload='Scalar')>`
- [x] Implement `<OpOverload(op='aten.remainder', overload='Scalar')>` (see the remainder sketch after this list)
- [x] Implement `<OpOverload(op='aten.remainder', overload='Tensor')>`
- [x] Implement `<OpOverload(op='aten.silu', overload='default')>`
- [x] Implement `<OpOverload(op='aten.unsafe_split', overload='Tensor')>`
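
For reference, here is a minimal NumPy sanity check of the `aten::lerp` decomposition added in `core.py` below (an illustration, not code from this commit). The two-branch form mirrors the `op.Where` in `aten_lerp`: the second branch rewrites the interpolation as `end - diff * (1 - weight)` so the result stays accurate when `weight` is near 1.

```python
import numpy as np

def lerp_decomposition(self, end, weight):
    # Mirrors the new aten_lerp: self + weight * (end - self) when
    # weight < 0.5, otherwise end - (end - self) * (1 - weight).
    diff = end - self
    return np.where(
        weight < 0.5,
        self + weight * diff,
        end - diff * (1.0 - weight),
    )

a = np.array([0.0, 1.0, 2.0], dtype=np.float32)
b = np.array([10.0, 10.0, 10.0], dtype=np.float32)
w = np.array([0.25, 0.5, 0.9], dtype=np.float32)
print(lerp_decomposition(a, b, w))  # [2.5 5.5 9.2]
```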
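
And the remainder sketch: `aten::remainder` uses floor-division semantics (the result takes the sign of the divisor, like Python's `%`), while `aten::fmod` truncates toward zero (the result takes the sign of the dividend). The `floor` step below is an assumption about how `rounded_quotient` is computed in the part of the `aten_remainder` hunk folded out of view in the diff.

```python
import numpy as np

def remainder_decomposition(self, other):
    # Assumed to match aten_remainder: self - floor(self / other) * other.
    rounded_quotient = np.floor(self / other)
    return self - rounded_quotient * other

x = np.array([-7.0, 7.0], dtype=np.float32)
y = np.array([3.0, -3.0], dtype=np.float32)
print(remainder_decomposition(x, y))  # [ 2. -2.]  (sign follows the divisor)
print(np.fmod(x, y))                  # [-1.  1.]  (sign follows the dividend)
```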

[**NOT PART OF THIS PR**] The following still require implementation functions
in torchlib eventually (not currently high priority):

- [ ] Implement `<OpOverload(op='aten.__rshift__', overload='Scalar')>`
- [ ] Implement `<OpOverload(op='aten._linalg_det', overload='default')>`
- [ ] Implement `<OpOverload(op='aten._linalg_slogdet', overload='default')>`
- [ ] Implement `<OpOverload(op='aten._prelu_kernel', overload='default')>`
- [ ] Implement `<OpOverload(op='aten.add', overload='Scalar')>`
- [ ] Implement `<OpOverload(op='aten.add', overload='Tensor')>`
- [ ] Implement `<OpOverload(op='aten.affine_grid_generator', overload='default')>`
- [ ] Implement `<OpOverload(op='aten.aminmax', overload='default')>`
- [ ] Implement `<OpOverload(op='aten.binary_cross_entropy_with_logits', overload='default')>`
- [ ] Implement `<OpOverload(op='aten.bitwise_and', overload='Tensor')>`
- [ ] Implement `<OpOverload(op='aten.bucketize', overload='Tensor')>`
- [ ] Implement `<OpOverload(op='aten.conv_tbc', overload='default')>`
- [ ] Implement `<OpOverload(op='aten.fake_quantize_per_tensor_affine_cachemask', overload='default')>`
- [ ] Implement `<OpOverload(op='aten.fill', overload='Scalar')>`
- [ ] Implement `<OpOverload(op='aten.index_add', overload='default')>`
- [ ] Implement `<OpOverload(op='aten.index_copy', overload='default')>`
- [ ] Implement `<OpOverload(op='aten.index_fill', overload='int_Scalar')>`
- [ ] Implement `<OpOverload(op='aten.index_put', overload='default')>`
- [ ] Implement `<OpOverload(op='aten.masked_scatter', overload='default')>`
- [ ] Implement `<OpOverload(op='aten.masked_select', overload='default')>`
- [ ] Implement `<OpOverload(op='aten.prod', overload='dim_int')>`
- [ ] Implement `<OpOverload(op='aten.rsub', overload='Tensor')>`
- [ ] Implement `<OpOverload(op='aten.scatter', overload='src')>`
- [ ] Implement `<OpOverload(op='aten.scatter', overload='value')>`
- [ ] Implement `<OpOverload(op='aten.sort', overload='default')>`
- [ ] Implement `<OpOverload(op='aten.std', overload='correction')>`
- [ ] Implement `<OpOverload(op='aten.std_mean', overload='correction')>`
- [ ] Implement `<OpOverload(op='aten.sym_size', overload='int')>`
- [ ] Implement `<OpOverload(op='aten.take', overload='default')>`
- [ ] Implement `<OpOverload(op='aten._adaptive_avg_pool2d', overload='default')>`
- [ ] Implement `<OpOverload(op='aten._cdist_forward', overload='default')>`
- [ ] Implement `<OpOverload(op='aten._convolution', overload='default')>`
- [ ] Implement `<OpOverload(op='aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams', overload='default')>`
- [ ] Implement `<OpOverload(op='aten.grid_sampler_3d', overload='default')>`
- [ ] Implement `<OpOverload(op='aten.hann_window', overload='default')>`
- [ ] Implement `<OpOverload(op='aten.im2col', overload='default')>`
- [ ] Implement `<OpOverload(op='aten.repeat_interleave', overload='Tensor')>`
- [ ] Implement `<OpOverload(op='torchvision.nms', overload='default')>`
- [ ] Implement `<OpOverload(op='torchvision.roi_align', overload='default')>`
- [ ] Implement `<OpOverload(op='torchvision.roi_pool', overload='default')>`
- [ ] Implement `<OpOverload(op='aten.nan_to_num', overload='default')>`
- [ ] Implement `<OpOverload(op='aten.nll_loss2d_forward', overload='default')>`
- [ ] Implement `<OpOverload(op='aten.nll_loss_forward', overload='default')>`
- [ ] Implement `<OpOverload(op='aten.norm', overload='ScalarOpt_dim_dtype')>`
- [ ] Implement `<OpOverload(op='aten.pixel_unshuffle', overload='default')>`

Add operator registration (see the registration sketch after this list):

- [ ] `aten::empty`
- [ ] `aten::fill`
- [ ] `aten::getitem`
- [ ] `aten::normal`
- [ ] `aten::rsub`
- [ ] `aten::scatter_reduce`
- [ ] `aten::select`
- [ ] `aten::slice`
- [ ] `aten::softmax`
- [ ] `aten::subtract`
- [ ] `aten::transpose`
- [ ] `aten::unbind`
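
As a rough illustration of what one of these registrations could look like, here is a hypothetical `aten::rsub` entry. The overload names, type annotations, and decomposition below are assumptions for illustration, not code from this commit:

```python
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TReal
from onnxscript.onnx_opset import opset18 as op


# Hypothetical registration: rsub(self, other, alpha) = other - alpha * self.
@torch_op(("aten::rsub.Tensor", "aten::rsub.Scalar"))
def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
    """rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
    return op.Sub(other, op.Mul(self, alpha))
```
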
shubhambhokare1 authored Jul 17, 2024
1 parent fb7dea4 commit f8ee736
Showing 3 changed files with 34 additions and 24 deletions.
43 changes: 21 additions & 22 deletions onnxscript/function_libs/torch_lib/ops/core.py
```diff
@@ -2236,23 +2236,13 @@ def aten_cov(
     raise NotImplementedError()


-@torch_op("aten::cross")
+@torch_op(("aten::cross", "aten::linalg_cross"))
 def aten_cross(self: TTensor, other: TTensor, dim: int = -1) -> TTensor:
     """cross(Tensor self, Tensor other, int? dim=None) -> Tensor"""

-    zero = op.Constant(value_ints=[0])
-    one = op.Constant(value_ints=[1])
-    two = op.Constant(value_ints=[2])
-    three = op.Constant(value_ints=[3])
-    axes = op.Expand(dim, op.Constant(value_ints=[1]))
-
     # Reference https://en.wikipedia.org/w/index.php?title=Cross_product&oldid=1143125073
-    a1 = op.Slice(self, zero, one, axes)
-    a2 = op.Slice(self, one, two, axes)
-    a3 = op.Slice(self, two, three, axes)
-    b1 = op.Slice(other, zero, one, axes)
-    b2 = op.Slice(other, one, two, axes)
-    b3 = op.Slice(other, two, three, axes)
+    a1, a2, a3 = op.Split(self, axis=dim, num_outputs=3)
+    b1, b2, b3 = op.Split(other, axis=dim, num_outputs=3)
     # Broadcasting is implicitly supported by Mul
     c1 = op.Sub(op.Mul(a2, b3), op.Mul(a3, b2))
     c2 = op.Sub(op.Mul(a3, b1), op.Mul(a1, b3))
@@ -3571,7 +3561,7 @@ def aten_fmin(self: TensorType, other: TensorType) -> TensorType:
     raise NotImplementedError()


-@torch_op("aten::fmod")
+@torch_op(("aten::fmod.Tensor", "aten::fmod.Scalar"))
 def aten_fmod(self: TRealOrUInt8, other: TRealOrUInt8) -> TRealOrUInt8:
     """fmod.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -4659,7 +4649,7 @@ def aten_le(self: TReal, other: TReal) -> BOOL:
     return op.LessOrEqual(self, other)


-@torch_op(("aten::le.Tensor", "aten::less_equal.Tensor", "_operator::le"))
+@torch_op(("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"))
 def aten_le_bool(self: BOOL, other: BOOL) -> BOOL:
     """le.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -4672,10 +4662,17 @@ def aten_le_bool(self: BOOL, other: BOOL) -> BOOL:
     return op.Or(other, op.Not(self))


-def aten_lerp(self: TensorType, end: TensorType, weight: TensorType) -> TensorType:
+@torch_op(("aten::lerp.Tensor", "aten::lerp.Scalar"))
+def aten_lerp(self: TTensor, end: TTensor, weight: TTensor) -> TTensor:
     """lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor"""

-    raise NotImplementedError()
+    weight = op.CastLike(weight, self)
+    diff = op.Sub(end, self)
+    return op.Where(
+        op.Less(weight, 0.5),
+        op.Add(self, op.Mul(weight, diff)),
+        op.Sub(end, op.Mul(diff, op.Sub(1.0, weight))),
+    )


 def aten_lgamma(self: TensorType) -> TensorType:
@@ -5619,10 +5616,11 @@ def aten_multiply(self: TensorType, other: TensorType) -> TensorType:
     raise NotImplementedError()


+@torch_op("aten::mv")
 def aten_mv(self: TensorType, vec: TensorType) -> TensorType:
     """mv(Tensor self, Tensor vec) -> Tensor"""

-    raise NotImplementedError()
+    return op.MatMul(self, vec)


 def aten_mvlgamma(self: TensorType, p: int) -> TensorType:
@@ -7011,7 +7009,7 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType:
     raise NotImplementedError()


-@torch_op("aten::remainder")
+@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"))
 def aten_remainder(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
     """remainder.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -7024,7 +7022,7 @@ def aten_remainder(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
     return op.Sub(self, op.Mul(rounded_quotient, other))


-@torch_op("aten::remainder")
+@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"))
 def aten_remainder_int(self: TInt, other: TInt) -> TInt:
     """remainder.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -8533,10 +8531,11 @@ def aten_unsafe_chunk(self: TensorType, chunks: int, dim: int = 0) -> TensorType:
     raise NotImplementedError()


-def aten_unsafe_split(self: TensorType, split_size: INT64, dim: int = 0) -> TensorType:
+@torch_op(("aten::unsafe_split", "aten::unsafe_split.Tensor"))
+def aten_unsafe_split(self: TTensor, split_size: INT64, dim: int = 0) -> Sequence[TTensor]:
     """unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]"""

-    raise NotImplementedError()
+    return op.SplitToSequence(self, split_size, axis=dim)


 def aten_unsafe_split_with_sizes(
```
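
A quick NumPy check of the `Split`-based `aten_cross` rewrite above (a sketch, not part of the commit). The third component, `c3 = a1 * b2 - a2 * b1`, is folded out of view in the hunk; it is assumed here from the standard cross-product formula:

```python
import numpy as np

def cross_decomposition(a, b, dim=-1):
    # Mirrors aten_cross: split both inputs into three components along dim.
    a1, a2, a3 = np.split(a, 3, axis=dim)
    b1, b2, b3 = np.split(b, 3, axis=dim)
    c1 = a2 * b3 - a3 * b2
    c2 = a3 * b1 - a1 * b3
    c3 = a1 * b2 - a2 * b1
    return np.concatenate([c1, c2, c3], axis=dim)

rng = np.random.default_rng(0)
a, b = rng.random((4, 3)), rng.random((4, 3))
assert np.allclose(cross_decomposition(a, b), np.cross(a, b))
```
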
5 changes: 3 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/linalg.py
```diff
@@ -17,7 +17,7 @@
 from onnxscript import BOOL, FLOAT, INT64
 from onnxscript.function_libs.torch_lib.ops import common as common_ops
 from onnxscript.function_libs.torch_lib.registration import torch_op
-from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
+from onnxscript.function_libs.torch_lib.tensor_typing import TFloat, TTensor
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType

@@ -44,9 +44,10 @@ def aten_linalg_cond(self: TensorType, p: Optional[float] = None) -> TensorType:
     raise NotImplementedError()


-def aten_linalg_cross(self: TensorType, other: TensorType, dim: int = -1) -> TensorType:
+def aten_linalg_cross(self: TTensor, other: TTensor, dim: int = -1) -> TTensor:
     """linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor"""

+    # Same implementation as aten_cross
     raise NotImplementedError()

```
10 changes: 10 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
```diff
@@ -900,6 +900,11 @@ def _where_input_wrangler(
     TorchLibOpInfo("log", core_ops.aten_log),
     TorchLibOpInfo("le", core_ops.aten_le),
     TorchLibOpInfo("le_bool", core_ops.aten_le_bool),
+    TorchLibOpInfo(
+        "lerp",
+        core_ops.aten_lerp,
+        tolerance={torch.float16: (2e-3, 2e-1)},
+    ),
     TorchLibOpInfo("log10", core_ops.aten_log10),
     TorchLibOpInfo("log1p", core_ops.aten_log1p),
     TorchLibOpInfo(
@@ -1020,6 +1025,11 @@ def _where_input_wrangler(
     TorchLibOpInfo("mT", core_ops.aten_mT_complex, complex=True),
     TorchLibOpInfo("mul", core_ops.aten_mul),
     TorchLibOpInfo("mul", core_ops.aten_mul_complex, complex=True),
+    TorchLibOpInfo(
+        "mv",
+        core_ops.aten_mv,
+        tolerance={torch.float16: (3e-2, 1e-2)},
+    ),
     TorchLibOpInfo("narrow", core_ops.aten_narrow),
     TorchLibOpInfo("ops.aten.native_dropout", core_ops.aten_native_dropout),
     TorchLibOpInfo("ne", core_ops.aten_ne),
```
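
The `tolerance` entries registered above are assumed to be per-dtype `(rtol, atol)` pairs; conceptually, the float16 comparison for `lerp` would behave like this sketch (the exact harness call is an assumption):

```python
import torch

# Assumed meaning of tolerance={torch.float16: (2e-3, 2e-1)}: rtol=2e-3, atol=2e-1.
actual = torch.tensor([9.203], dtype=torch.float16)
expected = torch.tensor([9.2], dtype=torch.float16)
torch.testing.assert_close(actual, expected, rtol=2e-3, atol=2e-1)
```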
