Skip to content

Commit

Permalink
Handle complex types in convert_element_type | fix(torchlib) (#1124)
Browse files Browse the repository at this point in the history
Previously when we call e.g. `convert_element_type(a, dtype=complex64)`,
we create a node with `COMPLEX64` output. This is not what we want
because we always the real representation of a complex number in the
ONNX graphs.

This change updates the logic to handle when the specified dtype is a
complex type, in which case we add an additional axis to the end of the
tensor and fill it with zeros to represent the imaginary part.

Tested locally:

```
>>> prims_convert_element_type(np.zeros(1), 14)
array([[0., 0.]], dtype=float32)
>>> prims_convert_element_type(np.zeros(1), 15)
array([[0., 0.]])
```

Fix #1122

---------

Co-authored-by: Ti-Tai Wang <[email protected]>
  • Loading branch information
justinchuby and titaiwangms authored Nov 3, 2023
1 parent aab5517 commit 5b50753
Showing 1 changed file with 25 additions and 4 deletions.
29 changes: 25 additions & 4 deletions onnxscript/function_libs/torch_lib/ops/prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@

from onnxscript import INT64
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TTensor
from onnxscript.function_libs.torch_lib.tensor_typing import RealType, TTensor
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType
from onnxscript.onnx_types import COMPLEX64, COMPLEX128, DOUBLE, FLOAT, TensorType

COMPLEX64_TYPE = COMPLEX64.dtype
COMPLEX128_TYPE = COMPLEX128.dtype


def prims_abs(self: TensorType) -> TensorType:
Expand Down Expand Up @@ -217,10 +220,28 @@ def prims_conj_physical(self: TensorType) -> TensorType:


@torch_op("prims::convert_element_type")
def prims_convert_element_type(a: TensorType, dtype: int) -> TensorType:
def prims_convert_element_type(a: RealType, dtype: int) -> RealType:
"""convert_element_type(Tensor a, ScalarType dtype) -> Tensor"""

return op.Cast(a, to=dtype)
if dtype == COMPLEX128_TYPE:
# Cast to the real representation of the complex type
casted = op.Cast(a, to=DOUBLE.dtype)
# Create a complex number
real_part = op.Unsqueeze(casted, axes=[-1])
imag_part = op.Expand(op.Cast(0.0, to=DOUBLE.dtype), op.Shape(real_part))
result = op.Concat(real_part, imag_part, axis=-1)
elif dtype == COMPLEX64_TYPE:
# Cast to the real representation of the complex type
casted = op.Cast(a, to=FLOAT.dtype)
# Create a complex number
real_part = op.Unsqueeze(casted, axes=[-1])
imag_part = op.Expand(0.0, op.Shape(real_part))
result = op.Concat(real_part, imag_part, axis=-1)
else:
# Cast to real numbers
result = op.Cast(a, to=dtype)

return result


def prims_copy_strided(a: TensorType, stride: INT64) -> TensorType:
Expand Down

0 comments on commit 5b50753

Please sign in to comment.