Handling bfloat16 constant propagation (#1484)
This is a temporary workaround for #1471, intended for experimentation and discussion. The issue appears to be that ONNX itself represents bfloat16 tensor constants as float32 arrays when they are converted to numpy. For constant propagation, this means we cannot rely solely on the numpy value's dtype to determine the ONNX type. The hack below suppresses constant propagation for bfloat16 constants, partly for that reason and partly because it is not yet clear to me whether this convention is supported by the ONNX reference implementation, by ORT, and so on. Assuming the backend supports it, we can try alternative solutions too. One possibility is simply to suppress constant propagation whenever the output types are unknown in the ONNX model.
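To illustrate why the numpy dtype alone is not enough, here is a minimal standard-library sketch, not the code in this commit. The helper names `bfloat16_bits_to_float` and `can_fold` are hypothetical; the element-type codes are the values of ONNX's `TensorProto.DataType` enum. The guard combines both ideas from the message: skip bfloat16, and skip when the declared type is unknown.

```python
import struct

# Element-type codes from ONNX's TensorProto.DataType enum.
FLOAT = 1
BFLOAT16 = 16


def bfloat16_bits_to_float(bits: int) -> float:
    """Widen a 16-bit bfloat16 pattern to a float via its float32 encoding."""
    # bfloat16 is the upper 16 bits of an IEEE-754 float32, so widening is a
    # left shift. After this step nothing in the value says it was bfloat16.
    return struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))[0]


def can_fold(declared_elem_type):
    """Hypothetical guard: fold only if the ONNX type is known and not bfloat16."""
    if declared_elem_type is None:  # output type unknown in the model
        return False
    return declared_elem_type != BFLOAT16


widened = bfloat16_bits_to_float(0x3FC0)  # bit pattern for 1.5 in bfloat16
print(widened, can_fold(FLOAT), can_fold(BFLOAT16), can_fold(None))
```

The point of the sketch is that the decision has to be made from the type declared in the model, since the widened value is indistinguishable from an ordinary float32 constant.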