Skip to content

Commit

Permalink
Add aten_prod function (#1724)
Browse files Browse the repository at this point in the history
The backward routine need aten_prod.dim_int function.
Todo: add test case for this function.
  • Loading branch information
xiaowuhu authored Jul 9, 2024
1 parent c74609c commit 60f2d2c
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6583,10 +6583,12 @@ def aten_prelu_backward(
raise NotImplementedError()


def aten_prod(self: TensorType, dtype: Optional[int] = None) -> TensorType:
@torch_op(("aten::prod.dim_int"), trace_only=True)
def aten_prod(self: TReal, dim: int, keepdim: bool = False) -> TReal:
"""prod(Tensor self, *, ScalarType? dtype=None) -> Tensor"""

raise NotImplementedError()
# Todo: add test for this function later
return op.ReduceProd(self, axes=[dim], keepdims=keepdim)


def aten_promote_types(type1: int, type2: int) -> int:
Expand Down

0 comments on commit 60f2d2c

Please sign in to comment.