From 3741ea5ed3498b04b44ff2318af8b309c66c072a Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Mon, 22 Jul 2024 17:15:57 -0700 Subject: [PATCH] Add op (std, std.dim, std.correction) | feat(torchlib) (#1747) Add std, std.dim, and std.correction --- .../function_libs/torch_lib/ops/core.py | 38 ++++++++++++++++++- .../function_libs/torch_lib/ops_test_data.py | 28 ++++++++++++++ 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index f984ed6b9..c9fb79f61 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7845,10 +7845,44 @@ def aten_stack(tensors: Sequence[TTensorOrString], dim: int = 0) -> TTensorOrStr return op.ConcatFromSequence(tensors, axis=dim, new_axis=1) -def aten_std(self: TensorType, unbiased: bool = True) -> TensorType: +@torch_op("aten::std", trace_only=True) +def aten_std(self: TReal, unbiased: bool = True) -> TReal: """std(Tensor self, bool unbiased=True) -> Tensor""" + var = _aten_var_onnx(self, correction=float(unbiased), keepdim=False) + return op.Sqrt(var) - raise NotImplementedError() + +@torch_op("aten::std.dim", trace_only=True) +def aten_std_dim( + self: TReal, + dim: Sequence[int], + unbiased: Optional[bool] = True, + keepdim: Optional[bool] = False, +) -> TReal: + """std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor""" + + var = _aten_var_dim_onnx(self, dims=dim, correction=float(unbiased), keepdim=keepdim) + return op.Sqrt(var) + + +@torch_op("aten::var.correction", trace_only=True) +def aten_std_correction( + self: TReal, + # FIXME(justinchuby): Make dim Optional[Sequence[int]] + dim: Optional[int] = None, + correction: Optional[float] = None, + keepdim: bool = False, +) -> TReal: + """std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor""" + + if correction is None: + correction = 1.0 + + if dim is None: + var = _aten_var_onnx(self, correction=correction, keepdim=keepdim) + else: + var = _aten_var_dim_onnx(self, dims=dim, correction=correction, keepdim=keepdim) + return op.Sqrt(var) def aten_std_mean(self: TensorType, unbiased: bool = True) -> tuple[TensorType, TensorType]: diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 8cb245908..0b7415e1f 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -2146,6 +2146,33 @@ def _where_input_wrangler( dtypes=(torch.float16,), reason="RuntimeError: MKL FFT doesn't support tensors of type: Half", ), + TorchLibOpInfo( + "std", + core_ops.aten_std, + ).xfail( + # kwargs must be empty + matcher=lambda sample: len(sample.kwargs) > 0, + reason="this Aten overload only support input[0]=tensor and input[1]=bool as input without any kwargs", + ), + TorchLibOpInfo( + "std_dim", + core_ops.aten_std_dim, + ).xfail( + # kwargs["dim"] must exist, kwargs["correction"] must not exist + matcher=lambda sample: not ( + sample.kwargs.get("dim", None) is not None + and sample.kwargs.get("correction", None) is None + ), + reason="this Aten overload only support with 'dim' argument and without 'correction' argument", + ), + TorchLibOpInfo( + "std_correction", + core_ops.aten_std_correction, + ).skip( + # Don't accept input[1]=bool and 'correction' must be in kwargs + matcher=lambda sample: len(sample.args) > 0 or "correction" not in sample.kwargs, + reason="this Aten overload only support when correction attribute exists", + ), TorchLibOpInfo( "sum", core_ops.aten_sum_dim_IntList, @@ -2295,6 +2322,7 @@ def _where_input_wrangler( ops_test_common.duplicate_opinfo(OPS_DB, "ops.aten._softmax", ("ops.aten._softmax_half",)) ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",)) ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",)) +ops_test_common.duplicate_opinfo(OPS_DB, "std", ("std_dim", "std_correction")) ops_test_common.duplicate_opinfo(OPS_DB, "var_mean", ("var_mean_dim", "var_mean_correction")) ops_test_common.duplicate_opinfo(OPS_DB, "var", ("var_dim", "var_correction")) ops_test_common.duplicate_opinfo(OPS_DB, "view_as_complex", ("view_as_complex_copy",))