diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 78aa0f6e8..80880ceae 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -6583,10 +6583,12 @@ def aten_prelu_backward( raise NotImplementedError() -def aten_prod(self: TensorType, dtype: Optional[int] = None) -> TensorType: +@torch_op(("aten::prod.dim_int"), trace_only=True) +def aten_prod(self: TReal, dim: int, keepdim: bool = False) -> TReal: """prod(Tensor self, *, ScalarType? dtype=None) -> Tensor""" - raise NotImplementedError() + # Todo: add test for this function later + return op.ReduceProd(self, axes=[dim], keepdims=keepdim) def aten_promote_types(type1: int, type2: int) -> int: