Skip to content

Commit

Permalink
Complex support for basic arithmetic | feat(torchlib) (#1144)
Browse files Browse the repository at this point in the history
Implement support for complex inputs for `+`, `-`, `*`, `/`.

Updated aten_abs_complex to use Slice instead of Gather for runtime performance.
  • Loading branch information
justinchuby authored Nov 10, 2023
1 parent 88ee668 commit 3f8ccda
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 12 deletions.
96 changes: 84 additions & 12 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,13 @@ def aten_abs(self: TRealOrUInt8) -> TRealOrUInt8:
def aten_abs_complex(self: TRealOrUInt8) -> TRealOrUInt8:
    """abs(Tensor self) -> Tensor

    Magnitude of a complex tensor encoded as a real tensor whose trailing
    axis holds (real, imag) pairs.
    """
    # self_real = self[..., 0]
    # Slice keeps the size-1 trailing axis (unlike Gather), so the real and
    # imaginary parts combine elementwise and a single Squeeze at the end
    # restores the expected output shape; Slice is also cheaper at runtime.
    self_real = op.Slice(self, [0], [1], axes=[-1])
    # self_imag = self[..., 1]
    self_imag = op.Slice(self, [1], [2], axes=[-1])
    real_pow = op.Pow(self_real, 2)
    imag_pow = op.Pow(self_imag, 2)
    real_plus_imag = op.Add(real_pow, imag_pow)
    # |a + bi| = sqrt(a^2 + b^2); drop the trailing complex axis.
    return op.Squeeze(op.Sqrt(real_plus_imag), axes=[-1])


@torch_op("aten::acos")
Expand All @@ -167,6 +167,13 @@ def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
return op.Add(self, other)


@torch_op(("aten::add", "aten::add.Tensor", "_operator::add"), trace_only=True, complex=True)
def aten_add_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
    """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
    # Complex tensors arrive as real tensors with a trailing (real, imag)
    # axis, and addition is elementwise, so the real-valued implementation
    # already computes the correct complex sum.
    return aten_add(self, other, alpha)


@torch_op("aten::addbmm")
def aten_addbmm(
self: TReal,
Expand Down Expand Up @@ -2630,6 +2637,39 @@ def aten_div(self: TFloat, other: TFloat) -> TFloat:
return op.Div(self, other)


@torch_op(
    (
        "aten::div",
        "aten::div.Tensor",
        "aten::div.Scalar",
        "aten::divide",
        "aten::true_divide",
        "_operator::truediv",
    ),
    complex=True,
)
def aten_div_complex(self: TFloat, other: TFloat) -> TFloat:
    """div.Tensor(Tensor self, Tensor other) -> Tensor

    Complex division on tensors whose trailing axis holds (real, imag) pairs.
    ``complex=True`` is required so this registers as the complex overload
    and does not collide with ``aten_div`` for the same ATen names.
    """

    # Complex division. PyTorch type promotion ensures both arguments are complex numbers
    self_real = op.Slice(self, [0], [1], axes=[-1])
    self_imag = op.Slice(self, [1], [2], axes=[-1])
    other_real = op.Slice(other, [0], [1], axes=[-1])
    other_imag = op.Slice(other, [1], [2], axes=[-1])

    # Complex division
    # (a + bi) / (c + di) = (ac + bd) / (c^2 + d^2) + (bc - ad) / (c^2 + d^2)i
    # https://mathworld.wolfram.com/ComplexDivision.html
    ac = op.Mul(self_real, other_real)
    bd = op.Mul(self_imag, other_imag)
    bc = op.Mul(self_imag, other_real)
    ad = op.Mul(self_real, other_imag)
    denominator = op.Add(op.Mul(other_real, other_real), op.Mul(other_imag, other_imag))
    real = op.Div(op.Add(ac, bd), denominator)
    imag = op.Div(op.Sub(bc, ad), denominator)

    return op.Concat(real, imag, axis=-1)


@torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True)
def aten_div_mode(self: TFloat, other: TFloat, rounding_mode: str) -> TFloat:
"""div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor"""
Expand Down Expand Up @@ -5304,8 +5344,7 @@ def aten_msort(self: TensorType) -> TensorType:
@torch_op(("aten::mul", "aten::mul.Tensor", "_operator::mul"))
def aten_mul(self: TReal, other: TReal) -> TReal:
    """mul.Tensor(Tensor self, Tensor other) -> Tensor"""
    # NOTE: the CastLike type-promotion workaround was removed in this commit;
    # both operands are expected to share a dtype by the time they get here.

    return op.Mul(self, other)


Expand All @@ -5319,6 +5358,29 @@ def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
return op.And(self, other)


@torch_op(("aten::mul", "aten::mul.Tensor", "_operator::mul"), complex=True)
def aten_mul_complex(self: TReal, other: TReal) -> TReal:
    """mul.Tensor(Tensor self, Tensor other) -> Tensor

    Complex multiplication on tensors whose trailing axis holds (real, imag)
    pairs. ``complex=True`` is required so this registers as the complex
    overload and does not collide with ``aten_mul`` for the same ATen names.
    """

    self_real = op.Slice(self, [0], [1], axes=[-1])
    self_imag = op.Slice(self, [1], [2], axes=[-1])
    other_real = op.Slice(other, [0], [1], axes=[-1])
    other_imag = op.Slice(other, [1], [2], axes=[-1])

    # Complex multiplication
    # (a + bi) * (c + di) = (ac - bd) + (ad + bc)i

    ac = op.Mul(self_real, other_real)
    bd = op.Mul(self_imag, other_imag)
    ad = op.Mul(self_real, other_imag)
    bc = op.Mul(self_imag, other_real)

    real = op.Sub(ac, bd)
    imag = op.Add(ad, bc)

    return op.Concat(real, imag, axis=-1)


@torch_op("aten::multinomial")
def aten_multinomial(
self: TFloat,
Expand Down Expand Up @@ -6967,12 +7029,17 @@ def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
@torch_op(("aten::rsub", "aten::rsub.Scalar"))
def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
    """rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
    # alpha arrives as a Python scalar attribute; cast it to the tensor dtype
    # before the Mul. (The CastLike on `other` was removed in this commit.)
    alpha = op.CastLike(alpha, self)

    # rsub reverses the operand order: other - alpha * self.
    return op.Sub(other, op.Mul(self, alpha))


@torch_op(("aten::rsub", "aten::rsub.Scalar"), trace_only=True, complex=True)
def aten_rsub_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
    """rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
    # Reversed subtraction acts elementwise on the trailing (real, imag)
    # axis, so delegating to the real-tensor implementation is sufficient.
    return aten_rsub(self, other, alpha=alpha)


@torch_op("aten::scalar_tensor")
def aten_scalar_tensor(s: float, dtype: int = FLOAT.dtype) -> TTensor: # type: ignore[type-var]
"""scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
Expand Down Expand Up @@ -7563,7 +7630,7 @@ def aten_stft(
return result


@torch_op(("aten::sub", "aten::sub.Tensor", "_operator::sub"))
@torch_op(("aten::sub", "aten::sub.Tensor", "aten::subtract", "_operator::sub"))
def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
"""sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
alpha = op.CastLike(alpha, other)
Expand All @@ -7572,10 +7639,15 @@ def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
return op.Sub(self, other)


@torch_op(
    ("aten::sub", "aten::sub.Tensor", "aten::subtract", "_operator::sub"),
    trace_only=True,
    complex=True,
)
def aten_sub_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
    """sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor

    Replaces the former unimplemented ``aten_subtract`` stub; ``aten::subtract``
    is now an alias handled by ``aten_sub`` (real) and this function (complex).
    """
    # Subtraction is elementwise over the trailing (real, imag) axis, so the
    # real-tensor implementation computes the correct complex difference.
    return aten_sub(self, other, alpha=alpha)


@torch_op(("aten::sum", "aten::sum.dim_IntList"), trace_only=True)
Expand Down
6 changes: 6 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,7 @@ def _where_input_wrangler(
TorchLibOpInfo("acos", core_ops.aten_acos),
TorchLibOpInfo("acosh", core_ops.aten_acosh),
TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}),
TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True, trace_only=True),
TorchLibOpInfo("addbmm", core_ops.aten_addbmm, tolerance={torch.float32: (2e-5, 2e-5)}),
TorchLibOpInfo("addcdiv", core_ops.aten_addcdiv),
TorchLibOpInfo("addcmul", core_ops.aten_addcmul, tolerance={torch.float16: (4e-3, 3e-3)}),
Expand Down Expand Up @@ -715,6 +716,8 @@ def _where_input_wrangler(
matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None,
reason="this variation does not take the rounding_mode argument",
),
TorchLibOpInfo("true_divide", core_ops.aten_div),
TorchLibOpInfo("true_divide", core_ops.aten_div_complex, complex=True),
TorchLibOpInfo("div_mode", core_ops.aten_div_mode, trace_only=True)
.skip(
variant_name="no_rounding_mode",
Expand Down Expand Up @@ -949,6 +952,7 @@ def _where_input_wrangler(
TorchLibOpInfo("mT", core_ops.aten_mT),
TorchLibOpInfo("mT", core_ops.aten_mT_complex, complex=True),
TorchLibOpInfo("mul", core_ops.aten_mul),
TorchLibOpInfo("mul", core_ops.aten_mul_complex, complex=True),
TorchLibOpInfo("narrow", core_ops.aten_narrow),
TorchLibOpInfo("ops.aten.native_dropout", core_ops.aten_native_dropout),
TorchLibOpInfo("ne", core_ops.aten_ne),
Expand Down Expand Up @@ -1299,6 +1303,7 @@ def _where_input_wrangler(
TorchLibOpInfo("round_decimals", core_ops.aten_round_decimals),
TorchLibOpInfo("rsqrt", core_ops.aten_rsqrt),
TorchLibOpInfo("rsub", core_ops.aten_rsub),
TorchLibOpInfo("rsub", core_ops.aten_rsub_complex, complex=True, trace_only=True),
TorchLibOpInfo(
"scalar_tensor",
core_ops.aten_scalar_tensor,
Expand Down Expand Up @@ -1392,6 +1397,7 @@ def _where_input_wrangler(
),
TorchLibOpInfo("stack", core_ops.aten_stack),
TorchLibOpInfo("sub", core_ops.aten_sub),
TorchLibOpInfo("sub", core_ops.aten_sub_complex, complex=True, trace_only=True),
# TorchLibOpInfo("sym_size", core_ops.aten_sym_size), # no test case in OPS_DB
TorchLibOpInfo(
"t",
Expand Down

0 comments on commit 3f8ccda

Please sign in to comment.