Skip to content

Commit

Permalink
Update cross impls
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhambhokare1 committed Jul 16, 2024
1 parent bcbf9f7 commit c370e35
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 31 deletions.
23 changes: 4 additions & 19 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2240,19 +2240,9 @@ def aten_cov(
def aten_cross(self: TTensor, other: TTensor, dim: int = -1) -> TTensor:
"""cross(Tensor self, Tensor other, int? dim=None) -> Tensor"""

zero = op.Constant(value_ints=[0])
one = op.Constant(value_ints=[1])
two = op.Constant(value_ints=[2])
three = op.Constant(value_ints=[3])
axes = op.Expand(dim, op.Constant(value_ints=[1]))

# Reference https://en.wikipedia.org/w/index.php?title=Cross_product&oldid=1143125073
a1 = op.Slice(self, zero, one, axes)
a2 = op.Slice(self, one, two, axes)
a3 = op.Slice(self, two, three, axes)
b1 = op.Slice(other, zero, one, axes)
b2 = op.Slice(other, one, two, axes)
b3 = op.Slice(other, two, three, axes)
a1, a2, a3 = op.Split(self, axis=dim, num_outputs=3)
b1, b2, b3 = op.Split(other, axis=dim, num_outputs=3)
# Broadcasting is implicitly supported by Mul
c1 = op.Sub(op.Mul(a2, b3), op.Mul(a3, b2))
c2 = op.Sub(op.Mul(a3, b1), op.Mul(a1, b3))
Expand Down Expand Up @@ -7299,15 +7289,10 @@ def aten_rrelu(
raise NotImplementedError()


@torch_op(("aten::__rshift__.Tensor", "aten::__rshift__.Scalar"))
def aten_rshift(self: TensorType, other: TensorType) -> TensorType:
"""__rshift__.Tensor(Tensor self, Tensor other) -> Tensor"""

other = op.Cast(other, to=FLOAT.dtype)
two_pow = op.Pow(2.0, other)
two_pow = op.CastLike(two_pow, self)
rshift = op.Div(self, two_pow)
return rshift
raise NotImplementedError()


@torch_op("aten::rsqrt")
Expand Down Expand Up @@ -8532,7 +8517,7 @@ def aten_unsafe_chunk(self: TensorType, chunks: int, dim: int = 0) -> TensorType


@torch_op(("aten::unsafe_split", "aten::unsafe_split.Tensor"))
def aten_unsafe_split(self: TTensor, split_size: INT64, dim: int = 0) -> TTensor:
def aten_unsafe_split(self: TTensor, split_size: INT64, dim: int = 0) -> Sequence[TTensor]:
"""unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]"""

return op.SplitToSequence(self, split_size, axis=dim)
Expand Down
14 changes: 2 additions & 12 deletions onnxscript/function_libs/torch_lib/ops/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,9 @@ def aten_linalg_cond(self: TensorType, p: Optional[float] = None) -> TensorType:
@torch_op("aten::linalg_cross")
def aten_linalg_cross(self: TTensor, other: TTensor, dim: int = -1) -> TTensor:

zero = op.Constant(value_ints=[0])
one = op.Constant(value_ints=[1])
two = op.Constant(value_ints=[2])
three = op.Constant(value_ints=[3])
axes = op.Expand(dim, op.Constant(value_ints=[1]))

# Reference https://en.wikipedia.org/w/index.php?title=Cross_product&oldid=1143125073
a1 = op.Slice(self, zero, one, axes)
a2 = op.Slice(self, one, two, axes)
a3 = op.Slice(self, two, three, axes)
b1 = op.Slice(other, zero, one, axes)
b2 = op.Slice(other, one, two, axes)
b3 = op.Slice(other, two, three, axes)
a1, a2, a3 = op.Split(self, axis=dim, num_outputs=3)
b1, b2, b3 = op.Split(other, axis=dim, num_outputs=3)
# Broadcasting is implicitly supported by Mul
c1 = op.Sub(op.Mul(a2, b3), op.Mul(a3, b2))
c2 = op.Sub(op.Mul(a3, b1), op.Mul(a1, b3))
Expand Down

0 comments on commit c370e35

Please sign in to comment.