From 3f8ccdab9c65e50e9a314c6b62bedb37dc435c53 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 10 Nov 2023 14:11:56 -0800
Subject: [PATCH] Complex support for basic arithmetic | feat(torchlib) (#1144)

Implement support for complex inputs for `+`, `-`, `*`, and `/`.
Also update `aten_abs_complex` to use Slice instead of Gather for better
runtime performance.
---
 .../function_libs/torch_lib/ops/core.py      | 97 +++++++++++++++++---
 .../function_libs/torch_lib/ops_test_data.py |  6 ++
 2 files changed, 91 insertions(+), 12 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index e9bb17217..7240e03e4 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -135,13 +135,13 @@ def aten_abs(self: TRealOrUInt8) -> TRealOrUInt8:
 def aten_abs_complex(self: TRealOrUInt8) -> TRealOrUInt8:
     """abs(Tensor self) -> Tensor"""
     # self_real = self[..., 0]
-    self_real = op.Gather(self, 0, axis=-1)
+    self_real = op.Slice(self, [0], [1], axes=[-1])
     # self_imag = self[..., 1]
-    self_imag = op.Gather(self, 1, axis=-1)
+    self_imag = op.Slice(self, [1], [2], axes=[-1])
     real_pow = op.Pow(self_real, 2)
     imag_pow = op.Pow(self_imag, 2)
     real_plus_imag = op.Add(real_pow, imag_pow)
-    return op.Sqrt(real_plus_imag)
+    return op.Squeeze(op.Sqrt(real_plus_imag), axes=[-1])
 
 
 @torch_op("aten::acos")
@@ -167,6 +167,13 @@ def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     return op.Add(self, other)
 
 
+@torch_op(("aten::add", "aten::add.Tensor", "_operator::add"), trace_only=True, complex=True)
+def aten_add_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
+    """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
+
+    return aten_add(self, other, alpha=alpha)
+
+
 @torch_op("aten::addbmm")
 def aten_addbmm(
     self: TReal,
@@ -2630,6 +2637,40 @@ def aten_div(self: TFloat, other: TFloat) -> TFloat:
     return op.Div(self, other)
 
 
+@torch_op(
+    (
+        "aten::div",
+        "aten::div.Tensor",
+        "aten::div.Scalar",
+        "aten::divide",
+        "aten::true_divide",
+        "_operator::truediv",
+    ),
+    complex=True,
+)
+def aten_div_complex(self: TFloat, other: TFloat) -> TFloat:
+    """div.Tensor(Tensor self, Tensor other) -> Tensor"""
+
+    # Complex division. PyTorch type promotion ensures both arguments are complex numbers
+    self_real = op.Slice(self, [0], [1], axes=[-1])
+    self_imag = op.Slice(self, [1], [2], axes=[-1])
+    other_real = op.Slice(other, [0], [1], axes=[-1])
+    other_imag = op.Slice(other, [1], [2], axes=[-1])
+
+    # Complex division
+    # (a + bi) / (c + di) = (ac + bd) / (c^2 + d^2) + (bc - ad) / (c^2 + d^2)i
+    # https://mathworld.wolfram.com/ComplexDivision.html
+    ac = op.Mul(self_real, other_real)
+    bd = op.Mul(self_imag, other_imag)
+    bc = op.Mul(self_imag, other_real)
+    ad = op.Mul(self_real, other_imag)
+    denominator = op.Add(op.Mul(other_real, other_real), op.Mul(other_imag, other_imag))
+    real = op.Div(ac + bd, denominator)
+    imag = op.Div(bc - ad, denominator)
+
+    return op.Concat(real, imag, axis=-1)
+
+
 @torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True)
 def aten_div_mode(self: TFloat, other: TFloat, rounding_mode: str) -> TFloat:
     """div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor"""
@@ -5304,8 +5345,7 @@ def aten_msort(self: TensorType) -> TensorType:
 @torch_op(("aten::mul", "aten::mul.Tensor", "_operator::mul"))
 def aten_mul(self: TReal, other: TReal) -> TReal:
     """mul.Tensor(Tensor self, Tensor other) -> Tensor"""
-    # FIXME(titaiwang): get rid of this when we have type_promotion
-    other = op.CastLike(other, self)
+
     return op.Mul(self, other)
 
 
@@ -5319,6 +5359,29 @@ def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
     return op.And(self, other)
 
 
+@torch_op(("aten::mul", "aten::mul.Tensor", "_operator::mul"), complex=True)
+def aten_mul_complex(self: TReal, other: TReal) -> TReal:
+    """mul.Tensor(Tensor self, Tensor other) -> Tensor"""
+
+    self_real = op.Slice(self, [0], [1], axes=[-1])
+    self_imag = op.Slice(self, [1], [2], axes=[-1])
+    other_real = op.Slice(other, [0], [1], axes=[-1])
+    other_imag = op.Slice(other, [1], [2], axes=[-1])
+
+    # Complex multiplication
+    # (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
+
+    ac = op.Mul(self_real, other_real)
+    bd = op.Mul(self_imag, other_imag)
+    ad = op.Mul(self_real, other_imag)
+    bc = op.Mul(self_imag, other_real)
+
+    real = op.Sub(ac, bd)
+    imag = op.Add(ad, bc)
+
+    return op.Concat(real, imag, axis=-1)
+
+
 @torch_op("aten::multinomial")
 def aten_multinomial(
     self: TFloat,
@@ -6967,12 +7030,17 @@ def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
 @torch_op(("aten::rsub", "aten::rsub.Scalar"))
 def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     """rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
-    # FIXME(titaiwang): get rid of this when we have type_promotion
-    other = op.CastLike(other, self)
-    alpha = op.CastLike(alpha, self)
+
     return op.Sub(other, op.Mul(self, alpha))
 
 
+@torch_op(("aten::rsub", "aten::rsub.Scalar"), trace_only=True, complex=True)
+def aten_rsub_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
+    """rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
+
+    return aten_rsub(self, other, alpha)
+
+
 @torch_op("aten::scalar_tensor")
 def aten_scalar_tensor(s: float, dtype: int = FLOAT.dtype) -> TTensor:  # type: ignore[type-var]
     """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
@@ -7563,7 +7631,7 @@ def aten_stft(
     return result
 
 
-@torch_op(("aten::sub", "aten::sub.Tensor", "_operator::sub"))
+@torch_op(("aten::sub", "aten::sub.Tensor", "aten::subtract", "_operator::sub"))
 def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     """sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
     alpha = op.CastLike(alpha, other)
@@ -7572,10 +7640,15 @@ def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     return op.Sub(self, other)
 
 
-def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1.0) -> TensorType:
-    """subtract.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
+@torch_op(
+    ("aten::sub", "aten::sub.Tensor", "aten::subtract", "_operator::sub"),
+    trace_only=True,
+    complex=True,
+)
+def aten_sub_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
+    """sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
 
-    raise NotImplementedError()
+    return aten_sub(self, other, alpha=alpha)
 
 
 @torch_op(("aten::sum", "aten::sum.dim_IntList"), trace_only=True)
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
index 551068195..bf11d9085 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -500,6 +500,7 @@ def _where_input_wrangler(
     TorchLibOpInfo("acos", core_ops.aten_acos),
     TorchLibOpInfo("acosh", core_ops.aten_acosh),
     TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}),
+    TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True, trace_only=True),
     TorchLibOpInfo("addbmm", core_ops.aten_addbmm, tolerance={torch.float32: (2e-5, 2e-5)}),
     TorchLibOpInfo("addcdiv", core_ops.aten_addcdiv),
     TorchLibOpInfo("addcmul", core_ops.aten_addcmul, tolerance={torch.float16: (4e-3, 3e-3)}),
@@ -715,6 +716,8 @@ def _where_input_wrangler(
         matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None,
         reason="this variation does not take the rounding_mode argument",
     ),
+    TorchLibOpInfo("true_divide", core_ops.aten_div),
+    TorchLibOpInfo("true_divide", core_ops.aten_div_complex, complex=True),
     TorchLibOpInfo("div_mode", core_ops.aten_div_mode, trace_only=True)
     .skip(
         variant_name="no_rounding_mode",
@@ -949,6 +952,7 @@ def _where_input_wrangler(
     TorchLibOpInfo("mT", core_ops.aten_mT),
     TorchLibOpInfo("mT", core_ops.aten_mT_complex, complex=True),
     TorchLibOpInfo("mul", core_ops.aten_mul),
+    TorchLibOpInfo("mul", core_ops.aten_mul_complex, complex=True),
     TorchLibOpInfo("narrow", core_ops.aten_narrow),
     TorchLibOpInfo("ops.aten.native_dropout", core_ops.aten_native_dropout),
     TorchLibOpInfo("ne", core_ops.aten_ne),
@@ -1299,6 +1303,7 @@ def _where_input_wrangler(
     TorchLibOpInfo("round_decimals", core_ops.aten_round_decimals),
     TorchLibOpInfo("rsqrt", core_ops.aten_rsqrt),
     TorchLibOpInfo("rsub", core_ops.aten_rsub),
+    TorchLibOpInfo("rsub", core_ops.aten_rsub_complex, complex=True, trace_only=True),
     TorchLibOpInfo(
         "scalar_tensor",
         core_ops.aten_scalar_tensor,
@@ -1392,6 +1397,7 @@ def _where_input_wrangler(
     ),
     TorchLibOpInfo("stack", core_ops.aten_stack),
     TorchLibOpInfo("sub", core_ops.aten_sub),
+    TorchLibOpInfo("sub", core_ops.aten_sub_complex, complex=True, trace_only=True),
     # TorchLibOpInfo("sym_size", core_ops.aten_sym_size),  # no test case in OPS_DB
     TorchLibOpInfo(
         "t",
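
--
Reviewer note, not part of the patch: torchlib represents a complex tensor as a
real tensor whose last dimension holds [real, imag] pairs, which is why the
functions above Slice on axis -1 and Concat the results back together. The
standalone numpy sketch below (helper names as_pairs/mul_pairs/div_pairs are
illustrative, not from the codebase) checks the multiplication and division
formulas used in the patch against Python's native complex arithmetic:

import numpy as np

def as_pairs(z):
    """Complex array -> real array with a trailing [real, imag] axis."""
    return np.stack([z.real, z.imag], axis=-1)

def mul_pairs(x, y):
    """(a + bi) * (c + di) = (ac - bd) + (ad + bc)i, on [..., 2] arrays."""
    a, b = x[..., :1], x[..., 1:]
    c, d = y[..., :1], y[..., 1:]
    return np.concatenate([a * c - b * d, a * d + b * c], axis=-1)

def div_pairs(x, y):
    """(a + bi) / (c + di) = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)."""
    a, b = x[..., :1], x[..., 1:]
    c, d = y[..., :1], y[..., 1:]
    denominator = c * c + d * d
    return np.concatenate(
        [(a * c + b * d) / denominator, (b * c - a * d) / denominator], axis=-1
    )

rng = np.random.default_rng(0)
z1 = rng.normal(size=4) + 1j * rng.normal(size=4)
z2 = rng.normal(size=4) + 1j * rng.normal(size=4)

# Both identities should agree with numpy's complex mul/div to float64 precision.
np.testing.assert_allclose(mul_pairs(as_pairs(z1), as_pairs(z2)), as_pairs(z1 * z2))
np.testing.assert_allclose(div_pairs(as_pairs(z1), as_pairs(z2)), as_pairs(z1 / z2))
print("mul/div formulas agree with numpy's complex arithmetic")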