Add op (std, std.dim, std.correction) | feat(torchlib) (#1747)
Add std, std.dim, and std.correction
titaiwangms authored Jul 23, 2024
1 parent 0e1dca6 commit 3741ea5
Showing 2 changed files with 64 additions and 2 deletions.
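
For context (not part of the commit): all three overloads compute the standard deviation as the square root of the variance under the same correction term, where unbiased=True corresponds to correction=1 (Bessel's correction) and unbiased=False to correction=0. A minimal sanity check of that relationship against PyTorch itself, assuming a recent PyTorch release:

import torch

x = torch.randn(4, 5)

# aten::std - unbiased flag only (correction = 1 when True, 0 when False)
assert torch.allclose(torch.std(x, unbiased=False), torch.var(x, unbiased=False).sqrt())

# aten::std.dim - reduce over the given dims, optionally keeping them
assert torch.allclose(torch.std(x, dim=1, keepdim=True), torch.var(x, dim=1, keepdim=True).sqrt())

# aten::std.correction - explicit correction term (defaults to 1 when omitted)
assert torch.allclose(torch.std(x, dim=(0,), correction=2), torch.var(x, dim=(0,), correction=2).sqrt())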
38 changes: 36 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -7845,10 +7845,44 @@ def aten_stack(tensors: Sequence[TTensorOrString], dim: int = 0) -> TTensorOrStr
     return op.ConcatFromSequence(tensors, axis=dim, new_axis=1)
 
 
-def aten_std(self: TensorType, unbiased: bool = True) -> TensorType:
+@torch_op("aten::std", trace_only=True)
+def aten_std(self: TReal, unbiased: bool = True) -> TReal:
     """std(Tensor self, bool unbiased=True) -> Tensor"""
-
-    raise NotImplementedError()
+    var = _aten_var_onnx(self, correction=float(unbiased), keepdim=False)
+    return op.Sqrt(var)
 
 
+@torch_op("aten::std.dim", trace_only=True)
+def aten_std_dim(
+    self: TReal,
+    dim: Sequence[int],
+    unbiased: Optional[bool] = True,
+    keepdim: Optional[bool] = False,
+) -> TReal:
+    """std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor"""
+
+    var = _aten_var_dim_onnx(self, dims=dim, correction=float(unbiased), keepdim=keepdim)
+    return op.Sqrt(var)
+
+
+@torch_op("aten::std.correction", trace_only=True)
+def aten_std_correction(
+    self: TReal,
+    # FIXME(justinchuby): Make dim Optional[Sequence[int]]
+    dim: Optional[int] = None,
+    correction: Optional[float] = None,
+    keepdim: bool = False,
+) -> TReal:
+    """std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor"""
+
+    if correction is None:
+        correction = 1.0
+
+    if dim is None:
+        var = _aten_var_onnx(self, correction=correction, keepdim=keepdim)
+    else:
+        var = _aten_var_dim_onnx(self, dims=dim, correction=correction, keepdim=keepdim)
+    return op.Sqrt(var)
 
 
 def aten_std_mean(self: TensorType, unbiased: bool = True) -> tuple[TensorType, TensorType]:
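
The unbiased flag is folded into the correction term via float(unbiased): True becomes 1.0 (divide by N - 1) and False becomes 0.0 (divide by N). A minimal numpy sketch of the variance formula that _aten_var_onnx and _aten_var_dim_onnx are assumed to implement (illustrative only):

import numpy as np

def var_with_correction(x: np.ndarray, correction: float) -> np.ndarray:
    # Sum of squared deviations from the mean, divided by (N - correction).
    n = x.size
    return ((x - x.mean()) ** 2).sum() / (n - correction)

x = np.random.rand(10).astype(np.float32)
assert np.isclose(var_with_correction(x, 1.0), x.var(ddof=1))  # unbiased=True
assert np.isclose(var_with_correction(x, 0.0), x.var(ddof=0))  # unbiased=False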
28 changes: 28 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -2146,6 +2146,33 @@ def _where_input_wrangler(
         dtypes=(torch.float16,),
         reason="RuntimeError: MKL FFT doesn't support tensors of type: Half",
     ),
+    TorchLibOpInfo(
+        "std",
+        core_ops.aten_std,
+    ).xfail(
+        # kwargs must be empty
+        matcher=lambda sample: len(sample.kwargs) > 0,
+        reason="this ATen overload only supports input[0]=tensor and input[1]=bool as inputs, without any kwargs",
+    ),
+    TorchLibOpInfo(
+        "std_dim",
+        core_ops.aten_std_dim,
+    ).xfail(
+        # kwargs["dim"] must exist, kwargs["correction"] must not exist
+        matcher=lambda sample: not (
+            sample.kwargs.get("dim", None) is not None
+            and sample.kwargs.get("correction", None) is None
+        ),
+        reason="this ATen overload only supports the 'dim' argument and does not support the 'correction' argument",
+    ),
+    TorchLibOpInfo(
+        "std_correction",
+        core_ops.aten_std_correction,
+    ).skip(
+        # No positional args accepted (input[1]=bool), and 'correction' must be in kwargs
+        matcher=lambda sample: len(sample.args) > 0 or "correction" not in sample.kwargs,
+        reason="this ATen overload is only used when the 'correction' argument is provided",
+    ),
     TorchLibOpInfo(
         "sum",
         core_ops.aten_sum_dim_IntList,
@@ -2295,6 +2322,7 @@ def _where_input_wrangler(
 ops_test_common.duplicate_opinfo(OPS_DB, "ops.aten._softmax", ("ops.aten._softmax_half",))
 ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",))
 ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",))
+ops_test_common.duplicate_opinfo(OPS_DB, "std", ("std_dim", "std_correction"))
 ops_test_common.duplicate_opinfo(OPS_DB, "var_mean", ("var_mean_dim", "var_mean_correction"))
 ops_test_common.duplicate_opinfo(OPS_DB, "var", ("var_dim", "var_correction"))
 ops_test_common.duplicate_opinfo(OPS_DB, "view_as_complex", ("view_as_complex_copy",))
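
The duplicate_opinfo call registers std_dim and std_correction as copies of the stock "std" OpInfo, so all three overloads are exercised against the same sample pool while the matchers above xfail or skip the samples an overload cannot handle. A sketch of how the matchers effectively partition samples, using a hypothetical expected_overload helper (not in the repo):

def expected_overload(sample_args: tuple, sample_kwargs: dict) -> str | None:
    # Mirrors the matcher logic above: std.correction takes kwarg-only samples
    # with 'correction'; std.dim takes samples with 'dim' but no 'correction';
    # plain std takes samples with no kwargs at all.
    if not sample_args and "correction" in sample_kwargs:
        return "aten::std.correction"
    if sample_kwargs.get("dim") is not None and sample_kwargs.get("correction") is None:
        return "aten::std.dim"
    if not sample_kwargs:
        return "aten::std"
    return None  # no overload under test accepts this sample

assert expected_overload((), {}) == "aten::std"
assert expected_overload((), {"dim": 1, "keepdim": True}) == "aten::std.dim"
assert expected_overload((), {"correction": 0}) == "aten::std.correction"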
