Unable to export Microsoft/Phi-3-small-8k-instruct ONNX model on CPU (Ubuntu 22.04.4 LTS) #908
Could you please try:
BTW, these are the combinations the tool supports now:
Thanks for the response @yufenglee. To export the Microsoft/Phi-3-small-8k-instruct ONNX model, CUDA is mandatory; we can't export it on CPU (because flash attention needs CUDA). But can we run the resulting model, exported with -e cuda, on the CPU? Is that right? And is there any way to export the model with input_ids and attention_mask as inputs (without position_ids, present, and past key values) and obtain only logits as the output?
CUDA is required to export the model for the reason you mentioned. '-e cuda' specifies that the exported ONNX model is targeted to run with the ONNX Runtime CUDA EP. You can use '-e cpu' to export an ONNX model that runs with the ORT CPU EP. '-p fp16/fp32/int4' specifies the data type of the ONNX model. position_ids and the present/past key values are required inputs of the model; we don't have an option to omit them right now. However, those inputs/outputs are managed by the ORT GenAI API automatically. You can get logits with the ORT GenAI API like this after exporting the model:
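A minimal sketch of that with the onnxruntime-genai Python API (the model path is an example, and method names differ a bit across releases — older versions set params.input_ids and call generator.compute_logits() instead of using append_tokens):

```python
import onnxruntime_genai as og

# Folder produced by the model builder (example path)
model = og.Model("./phi3-small-8k-instruct-onnx")
tokenizer = og.Tokenizer(model)

params = og.GeneratorParams(model)
params.set_search_options(max_length=64)

generator = og.Generator(model, params)
generator.append_tokens(tokenizer.encode("Hello, how are you?"))

# position_ids and the past/present KV caches are created and updated
# internally; you only feed token ids and read outputs such as logits.
generator.generate_next_token()
logits = generator.get_output("logits")
print(logits.shape)
```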
Here they mentioned that Phi-3 small ONNX models can now run on CPU.
Since that PR was merged, the changes have been added to the latest versions of ONNX Runtime and ONNX Runtime GenAI. You can install the latest stable versions to produce the Phi-3 small ONNX model for CPU instead of needing to build from source.
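For reference, a sketch of that flow (the output path is an example, and additional options such as a cache directory may be needed depending on your setup; the builder flags are the ones discussed above):

```
pip install --upgrade onnxruntime onnxruntime-genai

python -m onnxruntime_genai.models.builder \
  -m microsoft/Phi-3-small-8k-instruct \
  -o ./phi3-small-8k-instruct-onnx \
  -p int4 \
  -e cpu
```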
To export without those inputs and outputs, you can make the following modifications to the model builder:
onnxruntime-genai/src/python/py/models/builder.py, lines 495 to 507 (at commit f5af763)
onnxruntime-genai/src/python/py/models/builder.py, lines 1378 to 1381 (at commit f5af763)
As mentioned above, however, the past and present key-value caches are required to run with ONNX Runtime GenAI.
Here's how you can get around this issue.
Thanks for the detailed explanation @kunal-vaishnavi,
but this script throws the error below. Is there any additional modification that needs to be made in the model script?
This is an expected error no matter which opset version is used. The generated ONNX models contain operators that are in the com.microsoft domain (ONNX Runtime contrib ops), not just the default ONNX domain.
According to the op schema for SparseAttention, the past and present key-value caches are required. To get a valid ONNX model, you will need to undo the model builder changes you made so that the past and present key-value caches are added back as inputs and outputs to both the ONNX model and the GenAI configuration (genai_config.json).
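A quick way to see the contrib-op point for yourself (model path is an example): inspect the opset imports of the exported model, which list the com.microsoft domain alongside the default ONNX domain — that's why changing the default opset version alone doesn't resolve the error.

```python
import onnx

# Example path to the model produced by the model builder
model = onnx.load("./phi3-small-8k-instruct-onnx/model.onnx")

# The model builder emits ops from the default ONNX domain and from
# the com.microsoft domain (contrib ops such as SparseAttention).
for opset in model.opset_import:
    print(opset.domain or "ai.onnx", opset.version)
```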
Given that you will need the past and present key-value caches for the ONNX model, and from reading your inference script, it looks like you can use ONNX Runtime GenAI to simplify your inference script and improve model performance. Here is an example inference script for Phi-3 that applies the Phi-3 chat template. For a more general-purpose and simpler inference script, here is another example. You can also swap out ONNX Runtime GenAI's tokenizer for Hugging Face's tokenizer in these examples if you want.
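Here is a sketch of the kind of loop those examples use (paths are illustrative, and older onnxruntime-genai releases set params.input_ids and call generator.compute_logits() instead of using append_tokens):

```python
import onnxruntime_genai as og

model = og.Model("./phi3-small-8k-instruct-onnx")  # example path
tokenizer = og.Tokenizer(model)
stream = tokenizer.create_stream()

# Phi-3 chat template around the user prompt
prompt = "<|user|>\nWhat is the capital of France?<|end|>\n<|assistant|>\n"

params = og.GeneratorParams(model)
params.set_search_options(max_length=256)

generator = og.Generator(model, params)
generator.append_tokens(tokenizer.encode(prompt))

# Stream tokens until the generator reports completion
while not generator.is_done():
    generator.generate_next_token()
    print(stream.decode(generator.get_next_tokens()[0]), end="", flush=True)
print()
```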
Got it, thanks for the suggestions @kunal-vaishnavi.
So the conclusion is that, on CPU, exporting the Phi-3 small variants to an ONNX model without the past and present keys and values is not valid, as those are mandatory inputs to SparseAttention.
Yes, the LayerNorm-specific error will go away with opset 17 or higher. But since the ONNX model from the model builder always contains ops from the com.microsoft domain as well, changing the default opset version alone won't make the model pass standard ONNX validation.
Thanks for the script @kunal-vaishnavi
Export Microsoft/Phi-3-small-8k-instruct ONNX model on CPU (Ubuntu 22.04.4 LTS)
As per the suggestion, I referred to the ONNX Runtime build documentation and followed the steps below:
However, I encountered the following error:
AssertionError: Flash Attention is not available, but is needed for dense attention.
Detailed Trace:
Note: I verified my build by exporting the Phi-3-mini-4k-instruct model successfully.
Additionally, I want to export the model with input_ids and attention_mask as inputs (without position_ids, present, and past key values) and obtain logits as the output. Is there any way to achieve this?
Any help from members of the official repository would be greatly appreciated!