Skip to content

Commit

Permalink
add more impls
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhambhokare1 committed Jun 28, 2024
1 parent 8d78fa9 commit f06f3ea
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
19 changes: 16 additions & 3 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5625,10 +5625,11 @@ def aten_multiply(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::mv")
def aten_mv(self: TensorType, vec: TensorType) -> TensorType:
"""mv(Tensor self, Tensor vec) -> Tensor"""

raise NotImplementedError()
return op.MatMul(self, vec)

Check warning on line 5632 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L5632

Added line #L5632 was not covered by tests


def aten_mvlgamma(self: TensorType, p: int) -> TensorType:
Expand Down Expand Up @@ -6568,7 +6569,14 @@ def aten_positive(self: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op(("aten::pow.Tensor_Tensor", "aten::pow.Tensor_Scalar", "_operator::pow"))
@torch_op(
(
"aten::pow.Scalar",
"aten::pow.Tensor_Tensor",
"aten::pow.Tensor_Scalar",
"_operator::pow"
)
)
def aten_pow(self: TReal, exponent: TTensor) -> TReal:
"""pow(Tensor self, Tensor exponent) -> Tensor"""

Expand Down Expand Up @@ -7291,10 +7299,15 @@ def aten_rrelu(
raise NotImplementedError()


@torch_op(("aten::__rshift__.Tensor", "aten::__rshift__.Scalar"))
def aten_rshift(self: TensorType, other: TensorType) -> TensorType:
"""__rshift__.Tensor(Tensor self, Tensor other) -> Tensor"""

raise NotImplementedError()
other = op.Cast(other, to=FLOAT.dtype)
two_pow = op.Pow(2.0, other)
two_pow = op.CastLike(two_pow, self)
rshift = op.Div(self, two_pow)
return rshift

Check warning on line 7310 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L7306-L7310

Added lines #L7306 - L7310 were not covered by tests


@torch_op("aten::rsqrt")
Expand Down
3 changes: 2 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2043,10 +2043,11 @@ def aten_sigmoid_backward(grad_output: TensorType, output: TensorType) -> Tensor
raise NotImplementedError()


@torch_op("aten::silu")

Check warning on line 2046 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L2046

Added line #L2046 was not covered by tests
def aten_silu(self: TensorType) -> TensorType:
"""silu(Tensor self) -> Tensor"""

raise NotImplementedError()
return op.Mul(self, op.Sigmoid(self))


def aten_silu_backward(grad_output: TensorType, self: TensorType) -> TensorType:
Expand Down

0 comments on commit f06f3ea

Please sign in to comment.