From 22165942737e73cb5df35ebfce873eba6f096b0c Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Tue, 16 Jul 2024 16:33:25 +0000 Subject: [PATCH] Update cross impls --- .../function_libs/torch_lib/ops/core.py | 23 ++++--------------- .../function_libs/torch_lib/ops/linalg.py | 14 ++--------- 2 files changed, 6 insertions(+), 31 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index f53b2c56dd..927026418b 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2240,19 +2240,9 @@ def aten_cov( def aten_cross(self: TTensor, other: TTensor, dim: int = -1) -> TTensor: """cross(Tensor self, Tensor other, int? dim=None) -> Tensor""" - zero = op.Constant(value_ints=[0]) - one = op.Constant(value_ints=[1]) - two = op.Constant(value_ints=[2]) - three = op.Constant(value_ints=[3]) - axes = op.Expand(dim, op.Constant(value_ints=[1])) - # Reference https://en.wikipedia.org/w/index.php?title=Cross_product&oldid=1143125073 - a1 = op.Slice(self, zero, one, axes) - a2 = op.Slice(self, one, two, axes) - a3 = op.Slice(self, two, three, axes) - b1 = op.Slice(other, zero, one, axes) - b2 = op.Slice(other, one, two, axes) - b3 = op.Slice(other, two, three, axes) + a1, a2, a3 = op.Split(self, axis=dim, num_outputs=3) + b1, b2, b3 = op.Split(other, axis=dim, num_outputs=3) # Broadcasting is implicitly supported by Mul c1 = op.Sub(op.Mul(a2, b3), op.Mul(a3, b2)) c2 = op.Sub(op.Mul(a3, b1), op.Mul(a1, b3)) @@ -7299,15 +7289,10 @@ def aten_rrelu( raise NotImplementedError() -@torch_op(("aten::__rshift__.Tensor", "aten::__rshift__.Scalar")) def aten_rshift(self: TensorType, other: TensorType) -> TensorType: """__rshift__.Tensor(Tensor self, Tensor other) -> Tensor""" - other = op.Cast(other, to=FLOAT.dtype) - two_pow = op.Pow(2.0, other) - two_pow = op.CastLike(two_pow, self) - rshift = op.Div(self, two_pow) - return rshift + raise NotImplementedError() @torch_op("aten::rsqrt") @@ -8546,7 +8531,7 @@ def aten_unsafe_chunk(self: TensorType, chunks: int, dim: int = 0) -> TensorType @torch_op(("aten::unsafe_split", "aten::unsafe_split.Tensor")) -def aten_unsafe_split(self: TTensor, split_size: INT64, dim: int = 0) -> TTensor: +def aten_unsafe_split(self: TTensor, split_size: INT64, dim: int = 0) -> Sequence[TTensor]: """unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]""" return op.SplitToSequence(self, split_size, axis=dim) diff --git a/onnxscript/function_libs/torch_lib/ops/linalg.py b/onnxscript/function_libs/torch_lib/ops/linalg.py index 5a9d855925..16e16a5509 100644 --- a/onnxscript/function_libs/torch_lib/ops/linalg.py +++ b/onnxscript/function_libs/torch_lib/ops/linalg.py @@ -47,19 +47,9 @@ def aten_linalg_cond(self: TensorType, p: Optional[float] = None) -> TensorType: @torch_op("aten::linalg_cross") def aten_linalg_cross(self: TTensor, other: TTensor, dim: int = -1) -> TTensor: - zero = op.Constant(value_ints=[0]) - one = op.Constant(value_ints=[1]) - two = op.Constant(value_ints=[2]) - three = op.Constant(value_ints=[3]) - axes = op.Expand(dim, op.Constant(value_ints=[1])) - # Reference https://en.wikipedia.org/w/index.php?title=Cross_product&oldid=1143125073 - a1 = op.Slice(self, zero, one, axes) - a2 = op.Slice(self, one, two, axes) - a3 = op.Slice(self, two, three, axes) - b1 = op.Slice(other, zero, one, axes) - b2 = op.Slice(other, one, two, axes) - b3 = op.Slice(other, two, three, axes) + a1, a2, a3 = op.Split(self, axis=dim, num_outputs=3) + b1, b2, b3 = op.Split(other, axis=dim, num_outputs=3) # Broadcasting is implicitly supported by Mul c1 = op.Sub(op.Mul(a2, b3), op.Mul(a3, b2)) c2 = op.Sub(op.Mul(a3, b1), op.Mul(a1, b3))