From 8eec5aead1840524be29a0df6e6196632338ce3a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 6 Nov 2023 13:16:08 -0800 Subject: [PATCH] Fix var_mean when input dim is a list | fix(torchlib) (#1127) Previously `var_mean.dim` handles dim only when it is a tuple. It can be a list as well in practice. --- onnxscript/function_libs/torch_lib/ops/core.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index add9dd398..bbcdb1f25 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8015,16 +8015,15 @@ def aten_var_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]: @torch_op("aten::var_mean.dim", trace_only=True) def aten_var_mean_dim( - self: TReal, dim: Optional[int], unbiased: bool = True, keepdim: bool = False + self: TReal, dim: int, unbiased: bool = True, keepdim: bool = False ) -> Tuple[TReal, TReal]: """var_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)""" - # Although dim is Optional in signature, but we assume it must has value for this overload + # Although dim is Optional in signature, but we assume it must have value for this overload # Assert(dim is not None) - if isinstance(dim, Tuple): - dim_tensor = op.Constant(value_ints=dim) - else: - dim_tensor = op.Constant(value_int=dim) + if isinstance(dim, int): + dim = (dim,) + dim_tensor = op.Constant(value_ints=dim) return _aten_var_mean_dim_onnx( self, dim_tensor, correction=float(unbiased), keepdim=keepdim ) @@ -8045,10 +8044,9 @@ def aten_var_mean_correction( if dim is None: var, mean = _aten_var_mean_onnx(self, correction, keepdim) else: - if isinstance(dim, Tuple): - dim_tensor = op.Constant(value_ints=dim) - else: - dim_tensor = op.Constant(value_int=dim) + if isinstance(dim, int): + dim = (dim,) + dim_tensor = op.Constant(value_ints=dim) var, mean = _aten_var_mean_dim_onnx(self, dim_tensor, correction, keepdim) return var, mean