From 6a42940a6d975474310e7a01710e6e7ee8e5db49 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Mon, 22 Jul 2024 23:33:07 +0000 Subject: [PATCH] add std_mean --- .../function_libs/torch_lib/ops/core.py | 44 ++++++++++++++++++- .../function_libs/torch_lib/ops_test_data.py | 28 ++++++++++++ 2 files changed, 70 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 4fa43b056..c35898a65 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7827,10 +7827,50 @@ def aten_std(self: TensorType, unbiased: bool = True) -> TensorType: raise NotImplementedError() -def aten_std_mean(self: TensorType, unbiased: bool = True) -> tuple[TensorType, TensorType]: +@torch_op("aten::std_mean", trace_only=True) +def aten_std_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]: """std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)""" - raise NotImplementedError() + # Assume bool(True) and int(1) are same in ONNX, so pass "unbiased" directly as "correction" + # If not this case, should be explicitly set correction value according to unbiased value + var, mean = _aten_var_mean_onnx(self, correction=float(unbiased), keepdim=False) + return op.Sqrt(var), mean + + +@torch_op("aten::std_mean.dim", trace_only=True) +def aten_std_mean_dim( + self: TReal, dim: Sequence[int], unbiased: bool = True, keepdim: bool = False +) -> Tuple[TReal, TReal]: + """std_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)""" + + # Although dim is Optional in signature, but we assume it must have value for this overload + # Assert(dim is not None) + var, mean = _aten_var_mean_dim_onnx( + self, dims=dim, correction=float(unbiased), keepdim=keepdim + ) + return op.Sqrt(var), mean + + +@torch_op("aten::std_mean.correction", trace_only=True) +def aten_std_mean_correction( + self: TReal, + # FIXME(justinchuby): Make dim Optional[Sequence[int]] + dim: Optional[int] = None, + correction: Optional[float] = None, + keepdim: bool = False, +) -> Tuple[TReal, TReal]: + """std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)""" + + if correction is None: + correction = 1.0 + + if dim is None: + var, mean = _aten_var_mean_onnx(self, correction=correction, keepdim=keepdim) + else: + var, mean = _aten_var_mean_dim_onnx( + self, dims=dim, correction=correction, keepdim=keepdim + ) + return op.Sqrt(var), mean @torch_op("aten::stft", private=True) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index bad3e8eb6..c17b79261 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1506,6 +1506,33 @@ def _where_input_wrangler( ), TorchLibOpInfo("stack", core_ops.aten_stack), TorchLibOpInfo("stack", core_ops.aten_stack_complex, complex=True), + TorchLibOpInfo( + "std_mean", + core_ops.aten_std_mean, + ).xfail( + # kwargs is empty + matcher=lambda sample: len(sample.kwargs) > 0, + reason="this Aten overload only support input[0]=tensor and input[1]=bool as input without any kwargs", + ), + TorchLibOpInfo( + "std_mean_dim", + core_ops.aten_std_mean_dim, + ).xfail( + # kwargs["dim"] must exist, kwargs["correction"] must not exist + matcher=lambda sample: not ( + sample.kwargs.get("dim", None) is not None + and sample.kwargs.get("correction", None) is None + ), + reason="this Aten overload only support with 'dim' argument and without 'correction' argument", + ), + TorchLibOpInfo( + "std_mean_correction", + core_ops.aten_std_mean_correction, + ).skip( + # Don't accept input[1]=bool and 'correction' must be in kwargs + matcher=lambda sample: len(sample.args) > 0 or "correction" not in sample.kwargs, + reason="this Aten overload only support when correction attribute exists", + ), TorchLibOpInfo("sub", core_ops.aten_sub), TorchLibOpInfo("sub", core_ops.aten_sub_complex, complex=True), # TorchLibOpInfo("sym_size", core_ops.aten_sym_size), # no test case in OPS_DB @@ -2300,6 +2327,7 @@ def _where_input_wrangler( ops_test_common.duplicate_opinfo(OPS_DB, "ops.aten._softmax", ("ops.aten._softmax_half",)) ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",)) ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",)) +ops_test_common.duplicate_opinfo(OPS_DB, "std_mean", ("std_mean_dim", "std_mean_correction")) ops_test_common.duplicate_opinfo(OPS_DB, "var_mean", ("var_mean_dim", "var_mean_correction")) ops_test_common.duplicate_opinfo(OPS_DB, "var", ("var_dim", "var_correction")) ops_test_common.duplicate_opinfo(OPS_DB, "view_as_complex", ("view_as_complex_copy",))