From 95a400753f1e1435997f4c166f9cc1af92515212 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Tue, 9 Jul 2024 07:22:34 +0800 Subject: [PATCH 1/2] Update core.py --- onnxscript/function_libs/torch_lib/ops/core.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index b8535d46c..1802d95c0 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -6583,10 +6583,11 @@ def aten_prelu_backward( raise NotImplementedError() -def aten_prod(self: TensorType, dtype: Optional[int] = None) -> TensorType: +@torch_op(("aten::prod", "aten::prod.dim_int"), trace_only=True) +def aten_prod(self: TReal, dim: int, keepdim: bool = False) -> TReal: """prod(Tensor self, *, ScalarType? dtype=None) -> Tensor""" - raise NotImplementedError() + return op.ReduceProd(self, axes=[dim], keepdims=keepdim) def aten_promote_types(type1: int, type2: int) -> int: From 931059f86e50e750c2ce94ec95295aae64dcad7f Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Tue, 9 Jul 2024 14:45:16 +0800 Subject: [PATCH 2/2] Update core.py --- onnxscript/function_libs/torch_lib/ops/core.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index f5fdf2064..80880ceae 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -6583,10 +6583,11 @@ def aten_prelu_backward( raise NotImplementedError() -@torch_op(("aten::prod", "aten::prod.dim_int"), trace_only=True) +@torch_op(("aten::prod.dim_int"), trace_only=True) def aten_prod(self: TReal, dim: int, keepdim: bool = False) -> TReal: """prod(Tensor self, *, ScalarType? dtype=None) -> Tensor""" + # Todo: add test for this function later return op.ReduceProd(self, axes=[dim], keepdims=keepdim)