From 5b507536a41e4ba5f3d7761c429c87d900665907 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Nov 2023 10:08:47 -0700 Subject: [PATCH] Handle complex types in convert_element_type | fix(torchlib) (#1124) Previously when we call e.g. `convert_element_type(a, dtype=complex64)`, we create a node with `COMPLEX64` output. This is not what we want because we always the real representation of a complex number in the ONNX graphs. This change updates the logic to handle when the specified dtype is a complex type, in which case we add an additional axis to the end of the tensor and fill it with zeros to represent the imaginary part. Tested locally: ``` >>> prims_convert_element_type(np.zeros(1), 14) array([[0., 0.]], dtype=float32) >>> prims_convert_element_type(np.zeros(1), 15) array([[0., 0.]]) ``` Fix https://github.com/microsoft/onnxscript/issues/1122 --------- Co-authored-by: Ti-Tai Wang --- .../function_libs/torch_lib/ops/prims.py | 29 ++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/prims.py b/onnxscript/function_libs/torch_lib/ops/prims.py index ec4f74bfd..102739ba0 100644 --- a/onnxscript/function_libs/torch_lib/ops/prims.py +++ b/onnxscript/function_libs/torch_lib/ops/prims.py @@ -15,9 +15,12 @@ from onnxscript import INT64 from onnxscript.function_libs.torch_lib.registration import torch_op -from onnxscript.function_libs.torch_lib.tensor_typing import TTensor +from onnxscript.function_libs.torch_lib.tensor_typing import RealType, TTensor from onnxscript.onnx_opset import opset18 as op -from onnxscript.onnx_types import TensorType +from onnxscript.onnx_types import COMPLEX64, COMPLEX128, DOUBLE, FLOAT, TensorType + +COMPLEX64_TYPE = COMPLEX64.dtype +COMPLEX128_TYPE = COMPLEX128.dtype def prims_abs(self: TensorType) -> TensorType: @@ -217,10 +220,28 @@ def prims_conj_physical(self: TensorType) -> TensorType: @torch_op("prims::convert_element_type") -def prims_convert_element_type(a: TensorType, dtype: int) -> TensorType: +def prims_convert_element_type(a: RealType, dtype: int) -> RealType: """convert_element_type(Tensor a, ScalarType dtype) -> Tensor""" - return op.Cast(a, to=dtype) + if dtype == COMPLEX128_TYPE: + # Cast to the real representation of the complex type + casted = op.Cast(a, to=DOUBLE.dtype) + # Create a complex number + real_part = op.Unsqueeze(casted, axes=[-1]) + imag_part = op.Expand(op.Cast(0.0, to=DOUBLE.dtype), op.Shape(real_part)) + result = op.Concat(real_part, imag_part, axis=-1) + elif dtype == COMPLEX64_TYPE: + # Cast to the real representation of the complex type + casted = op.Cast(a, to=FLOAT.dtype) + # Create a complex number + real_part = op.Unsqueeze(casted, axes=[-1]) + imag_part = op.Expand(0.0, op.Shape(real_part)) + result = op.Concat(real_part, imag_part, axis=-1) + else: + # Cast to real numbers + result = op.Cast(a, to=dtype) + + return result def prims_copy_strided(a: TensorType, stride: INT64) -> TensorType: