From 60f2d2c0e9ed0ed63457580375d2fa9e06b88251 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Tue, 9 Jul 2024 14:57:13 +0800 Subject: [PATCH] Add aten_prod function (#1724) The backward routine need aten_prod.dim_int function. Todo: add test case for this function. --- onnxscript/function_libs/torch_lib/ops/core.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 78aa0f6e8..80880ceae 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -6583,10 +6583,12 @@ def aten_prelu_backward( raise NotImplementedError() -def aten_prod(self: TensorType, dtype: Optional[int] = None) -> TensorType: +@torch_op(("aten::prod.dim_int"), trace_only=True) +def aten_prod(self: TReal, dim: int, keepdim: bool = False) -> TReal: """prod(Tensor self, *, ScalarType? dtype=None) -> Tensor""" - raise NotImplementedError() + # Todo: add test for this function later + return op.ReduceProd(self, axes=[dim], keepdims=keepdim) def aten_promote_types(type1: int, type2: int) -> int: