Skip to content

Commit

Permalink
add std_mean
Browse files Browse the repository at this point in the history
  • Loading branch information
titaiwangms committed Jul 22, 2024
1 parent 2401de4 commit 6a42940
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 2 deletions.
44 changes: 42 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7827,10 +7827,50 @@ def aten_std(self: TensorType, unbiased: bool = True) -> TensorType:
raise NotImplementedError()


@torch_op("aten::std_mean", trace_only=True)
def aten_std_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
    """std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)"""

    # ONNX treats bool(True) and int(1) interchangeably, so the "unbiased"
    # flag maps directly onto the "correction" term (True -> 1.0, False -> 0.0).
    # If that ever stops holding, set correction explicitly from unbiased.
    variance, mean = _aten_var_mean_onnx(self, correction=float(unbiased), keepdim=False)
    # Standard deviation is the square root of the variance.
    return op.Sqrt(variance), mean


@torch_op("aten::std_mean.dim", trace_only=True)
def aten_std_mean_dim(
    self: TReal, dim: Sequence[int], unbiased: bool = True, keepdim: bool = False
) -> Tuple[TReal, TReal]:
    """std_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)"""

    # "dim" is optional in the ATen signature, but this overload is only
    # dispatched when a value is supplied, so no None-check is performed here.
    # The boolean "unbiased" maps onto the float "correction" term (True -> 1.0).
    variance, mean = _aten_var_mean_dim_onnx(
        self, dims=dim, correction=float(unbiased), keepdim=keepdim
    )
    # Standard deviation is the square root of the variance.
    return op.Sqrt(variance), mean

Check warning on line 7851 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L7851

Added line #L7851 was not covered by tests


@torch_op("aten::std_mean.correction", trace_only=True)
def aten_std_mean_correction(
    self: TReal,
    # FIXME(justinchuby): Make dim Optional[Sequence[int]]
    dim: Optional[int] = None,
    correction: Optional[float] = None,
    keepdim: bool = False,
) -> Tuple[TReal, TReal]:
    """std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)"""

    # A missing correction defaults to 1.0 (Bessel's correction), matching
    # the unbiased=True behavior of the other std_mean overloads.
    effective_correction = 1.0 if correction is None else correction

    if dim is None:
        # No dim supplied: reduce over every element of the input.
        variance, mean = _aten_var_mean_onnx(
            self, correction=effective_correction, keepdim=keepdim
        )
    else:
        # Reduce only over the requested dimension(s).
        variance, mean = _aten_var_mean_dim_onnx(
            self, dims=dim, correction=effective_correction, keepdim=keepdim
        )
    # Standard deviation is the square root of the variance.
    return op.Sqrt(variance), mean


@torch_op("aten::stft", private=True)
Expand Down
28 changes: 28 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1506,6 +1506,33 @@ def _where_input_wrangler(
),
TorchLibOpInfo("stack", core_ops.aten_stack),
TorchLibOpInfo("stack", core_ops.aten_stack_complex, complex=True),
TorchLibOpInfo(
"std_mean",
core_ops.aten_std_mean,
).xfail(
# kwargs is empty
matcher=lambda sample: len(sample.kwargs) > 0,
reason="this Aten overload only support input[0]=tensor and input[1]=bool as input without any kwargs",
),
TorchLibOpInfo(
"std_mean_dim",
core_ops.aten_std_mean_dim,
).xfail(
# kwargs["dim"] must exist, kwargs["correction"] must not exist
matcher=lambda sample: not (
sample.kwargs.get("dim", None) is not None
and sample.kwargs.get("correction", None) is None
),
reason="this Aten overload only support with 'dim' argument and without 'correction' argument",
),
TorchLibOpInfo(
"std_mean_correction",
core_ops.aten_std_mean_correction,
).skip(
# Don't accept input[1]=bool and 'correction' must be in kwargs
matcher=lambda sample: len(sample.args) > 0 or "correction" not in sample.kwargs,
reason="this Aten overload only support when correction attribute exists",
),
TorchLibOpInfo("sub", core_ops.aten_sub),
TorchLibOpInfo("sub", core_ops.aten_sub_complex, complex=True),
# TorchLibOpInfo("sym_size", core_ops.aten_sym_size), # no test case in OPS_DB
Expand Down Expand Up @@ -2300,6 +2327,7 @@ def _where_input_wrangler(
ops_test_common.duplicate_opinfo(OPS_DB, "ops.aten._softmax", ("ops.aten._softmax_half",))
ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",))
ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",))
ops_test_common.duplicate_opinfo(OPS_DB, "std_mean", ("std_mean_dim", "std_mean_correction"))
ops_test_common.duplicate_opinfo(OPS_DB, "var_mean", ("var_mean_dim", "var_mean_correction"))
ops_test_common.duplicate_opinfo(OPS_DB, "var", ("var_dim", "var_correction"))
ops_test_common.duplicate_opinfo(OPS_DB, "view_as_complex", ("view_as_complex_copy",))
Expand Down

0 comments on commit 6a42940

Please sign in to comment.