Skip to content

Commit

Permalink
Fix var_mean when input dim is a list | fix(torchlib) (#1127)
Browse files Browse the repository at this point in the history
Previously `var_mean.dim` handles dim only when it is a tuple. It can be
a list as well in practice.
  • Loading branch information
justinchuby authored Nov 6, 2023
1 parent 5b50753 commit 8eec5ae
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8015,16 +8015,15 @@ def aten_var_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:

@torch_op("aten::var_mean.dim", trace_only=True)
def aten_var_mean_dim(
self: TReal, dim: Optional[int], unbiased: bool = True, keepdim: bool = False
self: TReal, dim: int, unbiased: bool = True, keepdim: bool = False
) -> Tuple[TReal, TReal]:
"""var_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)"""

# Although dim is Optional in signature, but we assume it must has value for this overload
# Although dim is Optional in signature, but we assume it must have value for this overload
# Assert(dim is not None)
if isinstance(dim, Tuple):
dim_tensor = op.Constant(value_ints=dim)
else:
dim_tensor = op.Constant(value_int=dim)
if isinstance(dim, int):
dim = (dim,)
dim_tensor = op.Constant(value_ints=dim)
return _aten_var_mean_dim_onnx(
self, dim_tensor, correction=float(unbiased), keepdim=keepdim
)
Expand All @@ -8045,10 +8044,9 @@ def aten_var_mean_correction(
if dim is None:
var, mean = _aten_var_mean_onnx(self, correction, keepdim)
else:
if isinstance(dim, Tuple):
dim_tensor = op.Constant(value_ints=dim)
else:
dim_tensor = op.Constant(value_int=dim)
if isinstance(dim, int):
dim = (dim,)
dim_tensor = op.Constant(value_ints=dim)
var, mean = _aten_var_mean_dim_onnx(self, dim_tensor, correction, keepdim)
return var, mean

Expand Down

0 comments on commit 8eec5ae

Please sign in to comment.