From aded32433e52db8dea0483db0fdc4e3835a13665 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 4 Mar 2024 13:15:25 -0800 Subject: [PATCH] Implement ATen complex and polar | feat(torchlib) (#1286) aten::complex has a broadcasting behavior which is implemented here. **NOTE:** Optimizations should consider eliminating the `Expand` node when the broadcasted shape is the same as the input shape. Fixes https://github.com/pytorch/pytorch/issues/121100 --- .../function_libs/torch_lib/ops/core.py | 24 +++++++++++++++---- .../function_libs/torch_lib/ops_test_data.py | 2 ++ 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 6ac134552..9d21cad87 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -1641,10 +1641,23 @@ def aten_combinations( raise NotImplementedError() -def aten_complex(real: TensorType, imag: TensorType) -> TensorType: +@torch_op("aten::complex", private=True) +def _aten_complex(real: TFloat, imag: TFloat) -> TFloat: + """Non-broadcasting complex constructor.""" + + return op.Concat(op.Unsqueeze(real, axes=[-1]), op.Unsqueeze(imag, axes=[-1]), axis=-1) + + +@torch_op("aten::complex", trace_only=True) +def aten_complex(real: TFloat, imag: TFloat) -> TFloat: """complex(Tensor real, Tensor imag) -> Tensor""" - raise NotImplementedError() + # Broadcast the real and imaginary parts to the same shape + broadcasted_shape = _shape_of_broadcast_tensors(real, imag) + real = op.Expand(real, broadcasted_shape) + imag = op.Expand(imag, broadcasted_shape) + + return _aten_complex(real, imag) @torch_op("aten::concat") @@ -6385,10 +6398,13 @@ def aten_poisson_nll_loss( raise NotImplementedError() -def aten_polar(abs: TensorType, angle: TensorType) -> TensorType: +@torch_op("aten::polar") +def aten_polar(abs: TFloat, angle: TFloat) -> TFloat: """polar(Tensor abs, Tensor angle) -> Tensor""" - raise NotImplementedError() + real = op.Unsqueeze(op.Mul(abs, op.Cos(angle)), axes=[-1]) + imag = op.Unsqueeze(op.Mul(abs, op.Sin(angle)), axes=[-1]) + return op.Concat(real, imag, axis=-1) def aten_polygamma(n: int, self: TensorType) -> TensorType: diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 8373d784b..65d5b35f2 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -720,6 +720,7 @@ def _where_input_wrangler( reason="fixme (core dump): ORT aborts on scalar inputs to Reduce*-18. https://github.com/microsoft/onnxruntime/issues/16492", ), TorchLibOpInfo("clone", core_ops.aten_clone), + TorchLibOpInfo("complex", core_ops.aten_complex, trace_only=True), TorchLibOpInfo("concat", core_ops.aten_concat).skip( matcher=lambda sample: sample.input[0].equal(torch.tensor([])), reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619", @@ -1332,6 +1333,7 @@ def _where_input_wrangler( input_wrangler=_permute_input_wrangler, trace_only=True, ), + TorchLibOpInfo("polar", core_ops.aten_polar), TorchLibOpInfo("pow", core_ops.aten_pow), TorchLibOpInfo("ops.aten.rand", core_ops.aten_rand, nondeterministic=True), TorchLibOpInfo("ops.aten.rand_like", core_ops.aten_rand_like, nondeterministic=True),