aten::var implementation and aten::roll complex support (#1186)
As mentioned in [#1173](#1173), I'm adding aten::var and aten::roll (complex
support) in order to export a model from PyTorch to ONNX. The model uses fft
functions, which require opset 18 and the torch dynamo exporter.
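
For context, here is a minimal sketch of the export path this unlocks (a hypothetical model; assumes the dynamo-based exporter, `torch.onnx.dynamo_export`, available since PyTorch 2.1):

```python
import torch

class SpectralModel(torch.nn.Module):  # hypothetical example model
    def forward(self, x):
        spec = torch.fft.fft(x)                     # complex-valued output
        spec = torch.roll(spec, shifts=1, dims=-1)  # exercises aten::roll (complex)
        return torch.var(spec.real, dim=-1)         # exercises aten::var

model = SpectralModel()
# The fft functions need opset 18 and the dynamo-based exporter.
onnx_program = torch.onnx.dynamo_export(model, torch.randn(2, 16))
onnx_program.save("spectral_model.onnx")
```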

Fixes #1175
Fixes #1174

---------

Co-authored-by: Justin Chu <[email protected]>
luisfmnunes and justinchuby authored Nov 30, 2023
1 parent 5ba7efa commit 82d2063
Showing 2 changed files with 146 additions and 2 deletions.
110 changes: 108 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -6925,6 +6925,37 @@ def aten_roll(self: TTensor, shifts: INT64, dims: Sequence[int] = ()) -> TTensor
    return result


@torch_op("aten::roll", trace_only=True, complex=True)
def aten_roll_complex(self: TTensor, shifts: INT64, dims: Sequence[int] = ()) -> TTensor:
"""roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor"""

self_rank = len(self.shape)
if self_rank == 1:
return self

if self.shape[0] == 0: # empty tensor
return self

self_real = op.Slice(self, [0], [1], axes=[-1])
self_imag = op.Slice(self, [1], [2], axes=[-1])
if not dims:
# assert isinstance(shifts, int)
shift_real = _aten_roll_shift_no_dim_onnx(self_real, shifts)
shift_imag = _aten_roll_shift_no_dim_onnx(self_imag, shifts)

result = op.Concat(shift_real, shift_imag, axis=-1)

else:
# assert len(shifts) == len(dims), but shifts is a tensor, dims is a list
for i, dim in enumerate(dims):
shift = op.Gather(shifts, i, axis=0)
self_real = _aten_roll_shift_and_dim_onnx(self_real, shift, dim)
self_imag = _aten_roll_shift_and_dim_onnx(self_imag, shift, dim)

result = op.Concat(self_real, self_imag, axis=-1)
return result
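
As a sanity check on the approach above, here is a small numpy sketch (assuming torch_lib's convention of encoding a complex tensor as a real tensor whose trailing size-2 axis holds the real and imaginary parts) showing that rolling the two planes independently matches rolling the complex array itself:

```python
import numpy as np

rng = np.random.default_rng(0)
z = rng.standard_normal((3, 4)) + 1j * rng.standard_normal((3, 4))

# Encode as a real tensor of shape (3, 4, 2), mirroring torch.view_as_real.
encoded = np.stack([z.real, z.imag], axis=-1)

# Roll the real and imaginary planes separately along dim 1, as
# aten_roll_complex does, then concatenate them back together.
rolled = np.concatenate(
    [
        np.roll(encoded[..., :1], shift=2, axis=1),
        np.roll(encoded[..., 1:], shift=2, axis=1),
    ],
    axis=-1,
)

expected = np.roll(z, shift=2, axis=1)
assert np.allclose(rolled[..., 0], expected.real)
assert np.allclose(rolled[..., 1], expected.imag)
```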


@torch_op("aten::roll", private=True)
def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: INT64) -> TTensor:
neg_1 = op.Constant(value_ints=[-1])
@@ -8266,10 +8297,47 @@ def aten_vander(
    raise NotImplementedError()


def aten_var(self: TensorType, unbiased: bool = True) -> TensorType:
@torch_op("aten::var", trace_only=True)
def aten_var(self: TReal, unbiased: Optional[bool] = True) -> TReal:
    """var(Tensor self, bool unbiased=True) -> Tensor"""

    raise NotImplementedError()
    # Assume bool(True) and int(1) are the same in ONNX, so pass "unbiased" directly as "correction".
    # If that were not the case, the correction value would have to be set explicitly from "unbiased".
    return _aten_var_onnx(self, correction=float(unbiased), keepdim=False)


@torch_op("aten::var.dim", trace_only=True)
def aten_var_dim(
self: TReal, dim: int, unbiased: Optional[bool] = True, keepdim: Optional[bool] = False
) -> TReal:
"""var(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor"""

if isinstance(dim, int):
dim = (dim,)
dim_tensor = op.Constant(value_ints=dim)
return _aten_var_dim_onnx(self, dim_tensor, correction=float(unbiased), keepdim=keepdim)


@torch_op("aten::var.correction", trace_only=True)
def aten_var_correction(
self: TReal,
dim: Optional[int] = None,
correction: Optional[float] = None,
keepdim: bool = False,
) -> TReal:
"""var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor"""

if correction is None:
correction = 1.0

if dim is None:
var = _aten_var_onnx(self, correction, keepdim)
else:
if isinstance(dim, int):
dim = (dim,)
dim_tensor = op.Constant(value_ints=dim)
var = _aten_var_dim_onnx(self, dim_tensor, correction, keepdim)
return var
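
For reference, these three ATen overloads correspond to the following PyTorch-side calls (a sketch of the dispatch; which overload a given call traces to can vary by PyTorch version):

```python
import torch

x = torch.randn(4, 5)

torch.var(x, unbiased=True)         # aten::var            -> aten_var
torch.var(x, dim=0, unbiased=True)  # aten::var.dim        -> aten_var_dim
torch.var(x, dim=0, correction=2)   # aten::var.correction -> aten_var_correction
```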


@torch_op("aten::var_mean", trace_only=True)
@@ -8361,6 +8429,44 @@ def _aten_var_mean_dim_onnx(
    return var, mean


@torch_op("aten::var", private=True, traceable=True)
def _aten_var_onnx(self: TReal, correction: float, keepdim: bool = False) -> TReal:
mean = op.ReduceMean(self, keepdims=keepdim)
sub_mean = op.Sub(self, mean)
sqr_mean = op.Mul(sub_mean, sub_mean)
var = op.ReduceMean(sqr_mean, keepdims=keepdim)
# Adjust var according to correction value
if correction > 0.0:
self_shape = op.Shape(self)
numel_float = op.CastLike(op.ReduceProd(self_shape, keepdims=False), self)
mul = op.Mul(var, numel_float)
sub = op.Sub(numel_float, op.CastLike(correction, self))
var = op.Div(mul, sub)

return var
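
The rescaling in the correction branch follows the standard formula `var * N / (N - correction)`, which matches numpy's ddof semantics; a quick check (numpy used here as an assumed reference):

```python
import numpy as np

x = np.array([1.0, 2.0, 4.0, 7.0])
n = x.size
correction = 1.0  # Bessel's correction, i.e. unbiased=True

biased = ((x - x.mean()) ** 2).mean()     # ReduceMean of squared deviations
adjusted = biased * n / (n - correction)  # Mul by numel, Div by (numel - correction)
assert np.isclose(adjusted, np.var(x, ddof=1))
```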


@torch_op("aten::var.dim", private=True, traceable=True)
def _aten_var_dim_onnx(
self: TReal, dim: INT64, correction: float, keepdim: bool = False
) -> TReal:
dim = op.Reshape(dim, op.Constant(value_ints=[-1]))
# Computer mean and var
sub_mean = op.Sub(self, op.ReduceMean(self, dim, keepdims=True))
sqr_mean = op.Mul(sub_mean, sub_mean)
var = op.ReduceMean(sqr_mean, dim, keepdims=keepdim)
# Adjust var according to correction value
if correction > 0.0:
self_shape = op.Shape(self)
dim_size = op.Gather(self_shape, dim, axis=0)
numel_float = op.CastLike(op.ReduceProd(dim_size, keepdims=False), self)
mul = op.Mul(var, numel_float)
sub = op.Sub(numel_float, correction)
var = op.Div(mul, sub)

return var
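
The dim-wise variant applies the same rescaling, except the element count is the product of the sizes of the reduced dimensions (the Gather + ReduceProd pair above); checked against numpy the same way:

```python
import numpy as np

x = np.arange(24, dtype=np.float64).reshape(2, 3, 4)
dims = (0, 2)
correction = 1.0

# Gather the reduced dims' sizes from the shape and multiply them (ReduceProd).
n = np.prod([x.shape[d] for d in dims])

mean = x.mean(axis=dims, keepdims=True)
biased = ((x - mean) ** 2).mean(axis=dims)
adjusted = biased * n / (n - correction)
assert np.allclose(adjusted, np.var(x, axis=dims, ddof=1))
```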


def aten_vdot(self: TensorType, other: TensorType) -> TensorType:
"""vdot(Tensor self, Tensor other) -> Tensor"""

38 changes: 38 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -2075,6 +2075,13 @@ def _where_input_wrangler(
        trace_only=True,
        input_wrangler=_roll_input_wrangler,
    ),
    TorchLibOpInfo(
        "roll",
        core_ops.aten_roll_complex,
        input_wrangler=_roll_input_wrangler,
        trace_only=True,
        complex=True,
    ),
    TorchLibOpInfo(
        "scatter_reduce",
        core_ops.aten_scatter_reduce,
@@ -2182,6 +2189,36 @@ def _where_input_wrangler(
        matcher=lambda sample: len(sample.args) > 0 or "correction" not in sample.kwargs,
        reason="this ATen overload is only supported when the correction attribute exists",
    ),
    TorchLibOpInfo(
        "var",
        core_ops.aten_var,
        trace_only=True,
    ).xfail(
        # kwargs must be empty
        matcher=lambda sample: len(sample.kwargs) > 0,
        reason="this ATen overload only supports input[0]=tensor and input[1]=bool as inputs, without any kwargs",
    ),
    TorchLibOpInfo(
        "var_dim",
        core_ops.aten_var_dim,
        trace_only=True,
    ).xfail(
        # kwargs["dim"] must exist, kwargs["correction"] must not exist
        matcher=lambda sample: not (
            sample.kwargs.get("dim", None) is not None
            and sample.kwargs.get("correction", None) is None
        ),
        reason="this ATen overload only supports the 'dim' argument and does not support the 'correction' argument",
    ),
    TorchLibOpInfo(
        "var_correction",
        core_ops.aten_var_correction,
        trace_only=True,
    ).skip(
        # Does not accept input[1]=bool, and 'correction' must be in kwargs
        matcher=lambda sample: len(sample.args) > 0 or "correction" not in sample.kwargs,
        reason="this ATen overload is only supported when the correction attribute exists",
    ),
    TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like, trace_only=True),
)
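
To make the matcher semantics above concrete, here is a tiny sketch of what they filter (`SimpleNamespace` stands in for the test sample object, which exposes `.args` and `.kwargs`):

```python
from types import SimpleNamespace

def var_matcher(sample):
    # Matcher used for the plain "var" entry above: xfail any sample with kwargs.
    return len(sample.kwargs) > 0

# A bare positional call is handled by aten::var ...
assert not var_matcher(SimpleNamespace(args=(), kwargs={}))
# ... while any keyword argument routes the sample away from this overload.
assert var_matcher(SimpleNamespace(args=(), kwargs={"dim": 0}))
```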

@@ -2279,6 +2316,7 @@ def _where_input_wrangler(
ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",))
ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",))
ops_test_common.duplicate_opinfo(OPS_DB, "var_mean", ("var_mean_dim", "var_mean_correction"))
ops_test_common.duplicate_opinfo(OPS_DB, "var", ("var_dim", "var_correction"))
ops_test_common.duplicate_opinfo(OPS_DB, "view_as_complex", ("view_as_complex_copy",))
ops_test_common.duplicate_opinfo(OPS_DB, "view_as_real", ("view_as_real_copy",))

