onnxruntime 1.17.0 - fp16 model of inswapper causing render issues #19437
I can reproduce the issue. Let me try dumping node inputs/outputs to see which operator might cause the result change.
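For context, ONNX Runtime can dump node inputs/outputs when built with the onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS cmake define; dumping is then toggled through environment variables. A minimal sketch, assuming such a debug build:

import os

# Only effective in an ORT build configured with
# --cmake_extra_defines onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS=1
os.environ['ORT_DEBUG_NODE_IO_DUMP_INPUT_DATA'] = '1'
os.environ['ORT_DEBUG_NODE_IO_DUMP_OUTPUT_DATA'] = '1'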
The difference is caused by #17953: some Cast nodes are no longer removed in ORT 1.17.0. For example, the Cast node after Mul that caused the overflow can be removed safely using an offline tool. The proper fix should be done in the fp16 conversion tool, which should not add extra Cast nodes that can cause overflow. Some simple post-processing like the following should be enough:
Snippet of an example change to the run.py script that can run in ORT 1.17:
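The snippet itself did not survive in this copy of the thread. A minimal sketch of such post-processing, assuming it matches the remove_cascaded_cast_nodes() call used later in the thread (model paths are illustrative):

import onnx
from onnxruntime.transformers.onnx_model import OnnxModel

model = onnx.load('inswapper_128_fp16.onnx')
onnx_model = OnnxModel(model)
# Collapse back-to-back Cast chains left behind by the fp16 conversion
# tool; these are the nodes that reintroduce the overflow in ORT 1.17.
onnx_model.remove_cascaded_cast_nodes()
onnx_model.save_model_to_file('inswapper_128_fp16_fixed.onnx')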
For the CPU execution provider, it is better to run the fp32 model. The CPU has no fp16 kernels for most computation operators, so it has to convert to fp32 to run those operators (you can save the optimized model from the CPU provider to see this for yourself). If you run a benchmark, the fp32 model should be faster than the fp16 model on CPU.
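One way to save that optimized model, using the standard SessionOptions API (file names are illustrative):

import onnxruntime as ort

options = ort.SessionOptions()
# Serialize the graph after the CPU provider's optimizations; the dump
# shows the Cast-to-fp32 nodes inserted around fp32-only CPU kernels.
options.optimized_model_filepath = 'inswapper_128_fp16_cpu_optimized.onnx'
ort.InferenceSession('inswapper_128_fp16.onnx', options, providers=['CPUExecutionProvider'])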
Do you suggest waiting for a fix, or should we re-convert the model with the suggested tweaks?
@henryruhs, please re-convert the model. That's the fastest way to work around the issue. The fix I mentioned should be done in onnxconverter-common, and you can track its status here: microsoft/onnxconverter-common#271.
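For reference, the conversion step in question is onnxconverter-common's float16 converter; a typical invocation looks like this (paths are illustrative):

import onnx
from onnxconverter_common import float16

model = onnx.load('inswapper_128.onnx')
# This converter is what inserts the extra Cast nodes under discussion.
model_fp16 = float16.convert_float_to_float16(model)
onnx.save(model_fp16, 'inswapper_128_fp16.onnx')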
@tianleiwu We tried
@henryruhs, I verified that the issue was resolved using the code snippet I provided above. Note that the list of initializers might have changed, so you might need to use the name to find the initializer for INSWAPPER_MATRIX.
@tianleiwu Are you using the CUDA 12.2 version of onnxruntime 1.17.x?
Yes. I think it should also be fine with the CUDA 11.8 version of ORT 1.17.x once the extra Cast nodes are removed from the onnx model.
@tianleiwu First, thanks for your patience. I revisited the issue and figured out that the "fixed" version does indeed work when using the "original / before the fix" model initializer. That said, it seems to mess up the shape or internals once the cascaded Cast nodes have been removed.
I can verify the report with this code; it works under CUDA 12:
File: https://github.com/henryruhs/onnxruntime-fp16-issue/raw/master/inswapper_initializer.npy
Not sure what we can do from here; I wish we could undo the changes from #17953.
You can get the initializer by name (assuming the initializer name does not change). As for #17953, we will keep it. Maybe we can add an option to remove cascaded Cast nodes as a native optimizer so that users can apply it without a Python script.
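A sketch of the by-name lookup; the tensor name 'initializer' here is hypothetical and should be replaced with the actual name in the model:

import onnx
from onnxruntime.transformers.onnx_model import OnnxModel

model = onnx.load('inswapper_128_fp16.onnx')
onnx_model = OnnxModel(model)
# Positions in graph.initializer can shift after graph edits, so fetch
# the tensor by its stable name ('initializer' here is a placeholder).
tensor_proto = onnx_model.get_initializer('initializer')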
I fixed it by adding the initializer back afterwards:

import onnx
from onnx import numpy_helper
from onnxruntime.transformers.onnx_model import OnnxModel
import numpy as np

PATH = '.assets/models/'
SWAPPER_MODEL_PATH = PATH + 'inswapper_128_fp16.onnx'

model = onnx.load(SWAPPER_MODEL_PATH)
onnx_model = OnnxModel(model)
# Drop the cascaded Cast nodes that ORT 1.17 no longer removes.
onnx_model.remove_cascaded_cast_nodes()
# The last graph initializer is the inswapper weight matrix.
onnx_initializer = model.graph.initializer[-1]
# Re-append the matrix from a dump taken from the working model.
INSWAPPER_INITIALIZER = np.load("initializer.npy")
new_initializer_array = np.array(INSWAPPER_INITIALIZER)
model.graph.initializer.append(numpy_helper.from_array(new_initializer_array, name="initializer"))
onnx_model.save_model_to_file(PATH + "inswapper_128_fp16_v2.onnx", use_external_data_format=False, all_tensors_to_one_file=True)
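A quick smoke test of the rewritten model, assuming onnxruntime-gpu with the CUDA provider available:

import onnxruntime as ort

session = ort.InferenceSession(
    '.assets/models/inswapper_128_fp16_v2.onnx',
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
)
print(session.get_providers())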
Thanks again for the support |
Describe the issue
Since we updated to onnxruntime==1.17.0, the float16 version of the inswapper model stopped working and causes broken results depending on the integration. Falling back to onnxruntime==1.16.3 resolves the issue. It seems to be broken for CPU and CUDA but works for TensorRT.
Distorted face (CUDA) [screenshot]
Face box rendered black (CPU) [screenshot]
To reproduce
I created a dedicated repository to reproduce the issue and convert the model.
https://github.com/henryruhs/onnxruntime-fp16-issue
Urgency
Not sure how to define urgency, but this affects thousands of users, as our project (FaceFusion) is fairly popular.
Platform
Linux
OS Version
Ubuntu 22 LTS
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.17.0
ONNX Runtime API
Python
Architecture
X64
Execution Provider
Default CPU, CUDA
Execution Provider Library Version
No response