Skip to content

Commit

Permalink
Add linalg function
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhambhokare1 committed Jul 16, 2024
1 parent d592980 commit 1217822
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 6 deletions.
5 changes: 3 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8545,10 +8545,11 @@ def aten_unsafe_chunk(self: TensorType, chunks: int, dim: int = 0) -> TensorType
raise NotImplementedError()


def aten_unsafe_split(self: TensorType, split_size: INT64, dim: int = 0) -> TensorType:
@torch_op(("aten::unsafe_split", "aten::unsafe_split.Tensor"))
def aten_unsafe_split(self: TTensor, split_size: INT64, dim: int = 0) -> TTensor:
"""unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]"""

raise NotImplementedError()
return op.SplitToSequence(self, split_size, axis=dim)


def aten_unsafe_split_with_sizes(
Expand Down
26 changes: 22 additions & 4 deletions onnxscript/function_libs/torch_lib/ops/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from onnxscript import BOOL, FLOAT, INT64
from onnxscript.function_libs.torch_lib.ops import common as common_ops
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat, TTensor
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType

Expand All @@ -44,10 +44,28 @@ def aten_linalg_cond(self: TensorType, p: Optional[float] = None) -> TensorType:
raise NotImplementedError()


def aten_linalg_cross(self: TensorType, other: TensorType, dim: int = -1) -> TensorType:
"""linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor"""
@torch_op("aten::linalg_cross")
def aten_linalg_cross(self: TTensor, other: TTensor, dim: int = -1) -> TTensor:

raise NotImplementedError()
zero = op.Constant(value_ints=[0])
one = op.Constant(value_ints=[1])
two = op.Constant(value_ints=[2])
three = op.Constant(value_ints=[3])
axes = op.Expand(dim, op.Constant(value_ints=[1]))

# Reference https://en.wikipedia.org/w/index.php?title=Cross_product&oldid=1143125073
a1 = op.Slice(self, zero, one, axes)
a2 = op.Slice(self, one, two, axes)
a3 = op.Slice(self, two, three, axes)
b1 = op.Slice(other, zero, one, axes)
b2 = op.Slice(other, one, two, axes)
b3 = op.Slice(other, two, three, axes)
# Broadcasting is implicitly supported by Mul
c1 = op.Sub(op.Mul(a2, b3), op.Mul(a3, b2))
c2 = op.Sub(op.Mul(a3, b1), op.Mul(a1, b3))
c3 = op.Sub(op.Mul(a1, b2), op.Mul(a2, b1))

return op.Concat(c1, c2, c3, axis=dim)


@torch_op(("aten::_linalg_det", "aten::linalg_det", "aten::det"))
Expand Down

0 comments on commit 1217822

Please sign in to comment.