Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Handle complex types in convert_element_type | fix(torchlib) (#1124)
Previously when we call e.g. `convert_element_type(a, dtype=complex64)`, we create a node with `COMPLEX64` output. This is not what we want because we always the real representation of a complex number in the ONNX graphs. This change updates the logic to handle when the specified dtype is a complex type, in which case we add an additional axis to the end of the tensor and fill it with zeros to represent the imaginary part. Tested locally: ``` >>> prims_convert_element_type(np.zeros(1), 14) array([[0., 0.]], dtype=float32) >>> prims_convert_element_type(np.zeros(1), 15) array([[0., 0.]]) ``` Fix #1122 --------- Co-authored-by: Ti-Tai Wang <[email protected]>
- Loading branch information