diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index dee4f9baa..3bef9759d 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -6925,6 +6925,37 @@ def aten_roll(self: TTensor, shifts: INT64, dims: Sequence[int] = ()) -> TTensor return result +@torch_op("aten::roll", trace_only=True, complex=True) +def aten_roll_complex(self: TTensor, shifts: INT64, dims: Sequence[int] = ()) -> TTensor: + """roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor""" + + self_rank = len(self.shape) + if self_rank == 1: + return self + + if self.shape[0] == 0: # empty tensor + return self + + self_real = op.Slice(self, [0], [1], axes=[-1]) + self_imag = op.Slice(self, [1], [2], axes=[-1]) + if not dims: + # assert isinstance(shifts, int) + shift_real = _aten_roll_shift_no_dim_onnx(self_real, shifts) + shift_imag = _aten_roll_shift_no_dim_onnx(self_imag, shifts) + + result = op.Concat(shift_real, shift_imag, axis=-1) + + else: + # assert len(shifts) == len(dims), but shifts is a tensor, dims is a list + for i, dim in enumerate(dims): + shift = op.Gather(shifts, i, axis=0) + self_real = _aten_roll_shift_and_dim_onnx(self_real, shift, dim) + self_imag = _aten_roll_shift_and_dim_onnx(self_imag, shift, dim) + + result = op.Concat(self_real, self_imag, axis=-1) + return result + + @torch_op("aten::roll", private=True) def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: INT64) -> TTensor: neg_1 = op.Constant(value_ints=[-1]) @@ -8266,10 +8297,47 @@ def aten_vander( raise NotImplementedError() -def aten_var(self: TensorType, unbiased: bool = True) -> TensorType: +@torch_op("aten::var", trace_only=True) +def aten_var(self: TReal, unbiased: Optional[bool] = True) -> TReal: """var(Tensor self, bool unbiased=True) -> Tensor""" - raise NotImplementedError() + # Assume bool(True) and int(1) are same in ONNX, so pass "unbiased" directly as "correction" + # If not this case, should be explicitly set correction value according to unbiased value + return _aten_var_onnx(self, correction=float(unbiased), keepdim=False) + + +@torch_op("aten::var.dim", trace_only=True) +def aten_var_dim( + self: TReal, dim: int, unbiased: Optional[bool] = True, keepdim: Optional[bool] = False +) -> TReal: + """var(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor""" + + if isinstance(dim, int): + dim = (dim,) + dim_tensor = op.Constant(value_ints=dim) + return _aten_var_dim_onnx(self, dim_tensor, correction=float(unbiased), keepdim=keepdim) + + +@torch_op("aten::var.correction", trace_only=True) +def aten_var_correction( + self: TReal, + dim: Optional[int] = None, + correction: Optional[float] = None, + keepdim: bool = False, +) -> TReal: + """var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor""" + + if correction is None: + correction = 1.0 + + if dim is None: + var = _aten_var_onnx(self, correction, keepdim) + else: + if isinstance(dim, int): + dim = (dim,) + dim_tensor = op.Constant(value_ints=dim) + var = _aten_var_dim_onnx(self, dim_tensor, correction, keepdim) + return var @torch_op("aten::var_mean", trace_only=True) @@ -8361,6 +8429,44 @@ def _aten_var_mean_dim_onnx( return var, mean +@torch_op("aten::var", private=True, traceable=True) +def _aten_var_onnx(self: TReal, correction: float, keepdim: bool = False) -> TReal: + mean = op.ReduceMean(self, keepdims=keepdim) + sub_mean = op.Sub(self, mean) + sqr_mean = op.Mul(sub_mean, sub_mean) + var = op.ReduceMean(sqr_mean, keepdims=keepdim) + # Adjust var according to correction value + if correction > 0.0: + self_shape = op.Shape(self) + numel_float = op.CastLike(op.ReduceProd(self_shape, keepdims=False), self) + mul = op.Mul(var, numel_float) + sub = op.Sub(numel_float, op.CastLike(correction, self)) + var = op.Div(mul, sub) + + return var + + +@torch_op("aten::var.dim", private=True, traceable=True) +def _aten_var_dim_onnx( + self: TReal, dim: INT64, correction: float, keepdim: bool = False +) -> TReal: + dim = op.Reshape(dim, op.Constant(value_ints=[-1])) + # Computer mean and var + sub_mean = op.Sub(self, op.ReduceMean(self, dim, keepdims=True)) + sqr_mean = op.Mul(sub_mean, sub_mean) + var = op.ReduceMean(sqr_mean, dim, keepdims=keepdim) + # Adjust var according to correction value + if correction > 0.0: + self_shape = op.Shape(self) + dim_size = op.Gather(self_shape, dim, axis=0) + numel_float = op.CastLike(op.ReduceProd(dim_size, keepdims=False), self) + mul = op.Mul(var, numel_float) + sub = op.Sub(numel_float, correction) + var = op.Div(mul, sub) + + return var + + def aten_vdot(self: TensorType, other: TensorType) -> TensorType: """vdot(Tensor self, Tensor other) -> Tensor""" diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index ce0d4d851..e68745d15 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -2075,6 +2075,13 @@ def _where_input_wrangler( trace_only=True, input_wrangler=_roll_input_wrangler, ), + TorchLibOpInfo( + "roll", + core_ops.aten_roll_complex, + input_wrangler=_roll_input_wrangler, + trace_only=True, + complex=True, + ), TorchLibOpInfo( "scatter_reduce", core_ops.aten_scatter_reduce, @@ -2182,6 +2189,36 @@ def _where_input_wrangler( matcher=lambda sample: len(sample.args) > 0 or "correction" not in sample.kwargs, reason="this Aten overload only support when correction attribute exists", ), + TorchLibOpInfo( + "var", + core_ops.aten_var, + trace_only=True, + ).xfail( + # kwargs must be empty + matcher=lambda sample: len(sample.kwargs) > 0, + reason="this Aten overload only support input[0]=tensor and input[1]=bool as input without any kwargs", + ), + TorchLibOpInfo( + "var_dim", + core_ops.aten_var_dim, + trace_only=True, + ).xfail( + # kwargs["dim"] must exist, kwargs["correction"] must not exist + matcher=lambda sample: not ( + sample.kwargs.get("dim", None) is not None + and sample.kwargs.get("correction", None) is None + ), + reason="this Aten overload only support with 'dim' argument and without 'correction' argument", + ), + TorchLibOpInfo( + "var_correction", + core_ops.aten_var_correction, + trace_only=True, + ).skip( + # Don't accept input[1]=bool and 'correction' must be in kwargs + matcher=lambda sample: len(sample.args) > 0 or "correction" not in sample.kwargs, + reason="this Aten overload only support when correction attribute exists", + ), TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like, trace_only=True), ) @@ -2279,6 +2316,7 @@ def _where_input_wrangler( ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",)) ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",)) ops_test_common.duplicate_opinfo(OPS_DB, "var_mean", ("var_mean_dim", "var_mean_correction")) +ops_test_common.duplicate_opinfo(OPS_DB, "var", ("var_dim", "var_correction")) ops_test_common.duplicate_opinfo(OPS_DB, "view_as_complex", ("view_as_complex_copy",)) ops_test_common.duplicate_opinfo(OPS_DB, "view_as_real", ("view_as_real_copy",))