
Incorrect/Garbage Responses for Llama-2-7b-hf with INT4 GPTQ/RTN Asymmetric Quantization #19450

Closed
VishalX opened this issue Feb 7, 2024 · 16 comments
Labels: quantization (issues related to quantization), release:1.17.0


@VishalX

VishalX commented Feb 7, 2024

Describe the issue

I am trying to quantize and run Llama-2-7b-hf model using the example here.

I was able to generate the INT4 model with GPTQ quantization using the command in the "To reproduce" section below, with these settings:

Namespace(model_input='.\\llama2-7b-fp32\\', model_output='.\\Llama-2-7b-hf-gptq-asym', benchmark=False, quantize=True, batch_size=1, workspace='nc_workspace', algorithm='GPTQ', pad_max=196, seqlen=2048, tasks=['winogrande', 'copa', 'piqa', 'rte', 'hellaswag', 'openbookqa', 'lambada_openai', 'lambada_standard', 'wikitext'], dataset='NeelNanda/pile-10k', block_size=32, is_symmetric=False, accuracy_level=0, sampling_size=8)

However, when I run it on CPU, I get garbage output for every prompt:

- Prompt: ONNX Runtime is
- Response: ONNX Runtime is  prisoner categorieпута Clientública одногоúblicaública одногоúblicaúblicaúblicapplyúblicaúblicaúblicaúblicaúblicaúblicaúblicażeública geometricúblicażeúblicaúblicaúblicaúblicaúblicaúblicaúblicaúblicaúblicaுúblicaúblicaúblicaże zou[ întRunública Stim cruelF

- Prompt: I want to book a vacation to Hawaii. First, I need to
- Response: I want to book a vacation to Hawaii. First, I need to Statusifier liesStatusifierDOCTYPEissenschaft schedulecmpyed optyed optultan")yed opt diferenелісляcompos into")ultan intoultan optultan \( into oderifierultan rappresentultanел diferenyedyedམła intoyed into")cloudflareел

- Prompt: A good workout routine is
- Response: A good workout routine is 今设 gewesen gewesenісляwardwardwardward musical pueblo gewesen gewesen gewesen gewesenove gewesenoveісля instant zouwardxisісляwardісля instantoveRemoteісля gewesen только estaven толькоxis instantіслярия Wahl только zou서іслярияottiottiaba

- Prompt: How are astronauts launched into space?
- Response: How are astronauts launched into space? emarkemarkemark기 Wahl------+ел기ел기기yed finsелeringелłyyed finsyedелел기othy기 fatyed기temperaturen기기temperaturen thouісляtemperaturen기othy기yed Agutemperaturenелелел thouелinental

The RTN asymmetric INT4 model produces similar output.

To reproduce

I followed the WOQ README in onnxruntime-inference-examples:

python main.py --model_input .\llama2-7b-fp32\ --model_output .\Llama-2-7b-hf-gptq-asym --accuracy_level 0 --quantize --algorithm GPTQ

I used the inference code from here, with the following changes:

import torch

use_fp16 = False  # True when KV cache inputs/outputs are in float16
use_buffer_share = False  # True when --use_gqa was passed during export
device = torch.device("cpu")  # running on CPU

Urgency

No response

Platform

Windows

OS Version

Windows 11

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

v1.17.0

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

@VishalX VishalX changed the title Incorrect/Garbage Responses for Llama-2-7b-hf with INT4 GPTQ/RTN Symmetric Quantization Incorrect/Garbage Responses for Llama-2-7b-hf with INT4 GPTQ/RTN Asymmetric Quantization Feb 7, 2024
@VishalX
Author

VishalX commented Feb 14, 2024

Any update on this? @yufenglee / @kunal-vaishnavi

@yufenglee
Member

Hi @VishalX, could you try quantizing the model directly with a command like the following?
python -m onnxruntime.quantization.matmul_4bits_quantizer
Also, is your model a fine-tuned model or the original Llama 2?

@VishalX
Author

VishalX commented Feb 20, 2024

Hey @yufenglee,
I'm using the original Llama 2 (meta-llama/Llama-2-7b), exported to ONNX with the command below:

python -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b

Let me try quantizing it from the command line. I hope the command below is good enough:

python -m onnxruntime.quantization.matmul_4bits_quantizer --input_model <path-to-fp32-model> --output_model <path-to-int4-model> --block_size 32 --symmetric False --accuracy_level 0 --verbose
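To make the difference between --symmetric True and False concrete, here is a minimal round-to-nearest (RTN) sketch of blockwise INT4 quantization in plain Python. It is not MatMul4BitsQuantizer's actual code: the scale and zero-point formulas are generic assumptions, and the helper names (quantize_block, max_blockwise_error) are hypothetical. Symmetric mode pins the zero point at 8, while asymmetric mode stores a per-block zero point so a skewed block can use all 16 codes.

```python
from typing import List, Tuple

BITS = 4
QMAX = (1 << BITS) - 1  # 15: codes live in 0..15


def quantize_block(block: List[float], symmetric: bool) -> Tuple[List[int], float, int]:
    """RTN 4-bit quantization of one block; returns (codes, scale, zero_point)."""
    if symmetric:
        # Zero point pinned at 8, so codes 0..15 map to -8..7; scale set by max |w|.
        amax = max(abs(w) for w in block)
        scale = amax / 8 if amax > 0 else 1.0
        zero_point = 8
    else:
        # Map [min, max] (widened to include 0) onto 0..15 with a per-block zero point.
        lo = min(min(block), 0.0)
        hi = max(max(block), 0.0)
        scale = (hi - lo) / QMAX if hi > lo else 1.0
        zero_point = round(-lo / scale)
    # Round to nearest, then clamp into the 4-bit code range.
    codes = [min(QMAX, max(0, round(w / scale) + zero_point)) for w in block]
    return codes, scale, zero_point


def dequantize_block(codes: List[int], scale: float, zero_point: int) -> List[float]:
    return [(c - zero_point) * scale for c in codes]


def max_blockwise_error(weights: List[float], block_size: int, symmetric: bool) -> float:
    """Quantize `weights` in blocks of `block_size` and return the worst absolute error."""
    worst = 0.0
    for start in range(0, len(weights), block_size):
        block = weights[start:start + block_size]
        codes, scale, zp = quantize_block(block, symmetric)
        restored = dequantize_block(codes, scale, zp)
        worst = max(worst, max(abs(a - b) for a, b in zip(block, restored)))
    return worst


if __name__ == "__main__":
    # A skewed block (all values in [0.5, 1.0]) where a per-block zero point helps.
    weights = [0.5 + i / 62 for i in range(32)]
    print("symmetric :", max_blockwise_error(weights, 32, symmetric=True))
    print("asymmetric:", max_blockwise_error(weights, 32, symmetric=False))
```

On this skewed block the asymmetric variant has the smaller worst-case error, because the symmetric variant wastes the negative half of its code range. The extra zero point stored per block of 32 weights is the cost of that flexibility.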

@yufenglee
Member

Yes, please try this command line.

@VishalX
Author

VishalX commented Feb 20, 2024

@yufenglee
I just tried it and printed the parsed args. The --symmetric flag is not being set to False:

Namespace(input_model='llama2-7b-fp32/rank_0_Llama-2-7b-hf_decoder_merged_model_fp32_opt.onnx', output_model='bw_asym/model.onnx', block_size=32, symmetric=True, accuracy_level=0, verbose=True, nodes_to_exclude=[])

I'll fix this locally and try again.
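A likely cause, assuming the script declares --symmetric with argparse's type=bool (not verified here against the 1.17.0 source): argparse passes the raw string to bool(), and bool("False") is True because any non-empty string is truthy. The sketch below reproduces that behavior and shows one possible fix using a hypothetical str2bool converter.

```python
import argparse


def str2bool(value: str) -> bool:
    """Parse common true/false spellings; argparse's type=bool cannot do this."""
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes", "y"):
        return True
    if lowered in ("false", "0", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


# Pitfall: type=bool calls bool("False"), and any non-empty string is truthy.
buggy = argparse.ArgumentParser()
buggy.add_argument("--symmetric", default=True, type=bool)

# One possible fix: convert the string explicitly.
fixed = argparse.ArgumentParser()
fixed.add_argument("--symmetric", default=True, type=str2bool)

if __name__ == "__main__":
    argv = ["--symmetric", "False"]
    print("type=bool    :", buggy.parse_args(argv).symmetric)  # True
    print("type=str2bool:", fixed.parse_args(argv).symmetric)  # False
```

On Python 3.9+, argparse.BooleanOptionalAction is another option; it generates a --symmetric/--no-symmetric pair instead of parsing a string value.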

@VishalX
Author

VishalX commented Feb 20, 2024

@yufenglee I get garbage output with asymmetric quantization.

📌 Running on Windows

With Asymmetric Quantization (block size = 32, accuracy level 0)

- Prompt: ONNX Runtime is
- Response: ONNX Runtime is iformesuttтку il prec Sé WagnertziblizioneestoneTItzbero blindkretPortailbez https blindec demselbenestone ris factioromp totalitéius blindattle Meteor yourselfkretcepметspanartaestonekretkretkretshotkre典estonewod

- Prompt: I want to book a vacation to Hawaii. First, I need to
- Response: I want to book a vacation to Hawaii. First, I need to rameign commissioninale dispogo Magn commissioncyk Transport refterraLABterra blindinalebez attzk mieszkańxaizonanonranoranoranorano wojewalyranoletterano prüitorActivityThread Bayerlop Bayerbers Helsixenemploerei BayerÉt запаizon

- Prompt: A good workout routine is
- Response: A good workout routine is embly wat fig immterneient j externas graonymgor面 blindisserрахcribe dispščececa tippenasutaletartalettetzлезasonsREATEetheabgerufengorrefsifactTCciasTCemploTCissantTCflutterTC consultéTCiella

- Prompt: How are astronauts launched into space?
- Response: How are astronauts launched into space? burgoромаometric Knoabeierni Feldggi blindбурython Herzterneclipseillébulletlauki Mann GlMRrefixodybergerlijk Rub sor [" Rosa fetetz Transportifact TransportромаTC blind ursottiorig yletteatzrezrez Voor Anderson

With Symmetric Quantization (block size = 32, accuracy level 0)

- Prompt: ONNX Runtime is
- Response: ONNX Runtime is 100% open source and is available on GitHub. Hinweis: Die folgende Seite ist nur auf Englisch verfügbar.
The ONNX Runtime is a C++ library that allows you to run models in ON

- Prompt: I want to book a vacation to Hawaii. First, I need to
- Response: I want to book a vacation to Hawaii. First, I need to 1) find a hotel, 2) find a flight, and 3) find a rental car.
I've been to Hawaii before, so I know what I like. I've stayed at the Out

- Prompt: A good workout routine is
- Response: A good workout routine is 30 minutes of cardio and 30 minutes of weight training. nobody is going to get ripped from just cardio.
I'm not sure if you're being sarcastic or not, but I'

- Prompt: How are astronauts launched into space?
- Response: How are astronauts launched into space? 1. nobody knows.
How are astronauts launched into space?
1. nobody knows.
2. nobody knows.
3. nobody knows.
4. nobody knows.
5. nobody knows.
6

@VishalX
Author

VishalX commented Feb 20, 2024

Interestingly, if I run the same model on Linux (Ubuntu 18.04), I get somewhat better results with the asymmetric model, but I still see non-English sentences/words in the responses. The responses from the symmetric quantized model, however, match on Windows and Linux.

With Asymmetric Quantization (block size = 32, accuracy level 0)

📌 Running on Linux

- Prompt: ONNX Runtime is 
- Response: ONNX Runtime is 100% open source and free to use. Hinweis: Die folgende Liste enthält nur die wichtigsten Features.
ONNX Runtime is a cross-platform, open source, and free to use runtime for

- Prompt: I want to book a vacation to Hawaii. First, I need to 
- Response: I want to book a vacation to Hawaii. First, I need to 1) find a good travel agent, 2) find a good hotel, and 3) find a good airline.
I've been to Hawaii before, and I know that I want to stay in Waik

- Prompt: A good workout routine is 
- Response: A good workout routine is 30 minutes of exercise, 3 times a week. Hinweis: Die Angaben über die Nährstoffe sind nur Richtwerte.
The best way to get a good workout routine is to start with a 

- Prompt: How are astronauts launched into space? 
- Response: How are astronauts launched into space? 1. Hinweis: Die Antworten sind in der Regel nicht so einfach wie sie aussieht.
How are astronauts launched into space?
How are astronauts launched into space? 1.

@yufenglee
Member

I can reproduce the issue locally.

  • For symmetric quantization, the "Hinweis: Die folgende Seite ist nur auf Englisch verfügbar." in the response to the 1st prompt is German for "Note: The following page is only available in English." I ran the same model with the CUDA EP and got the same result, so it is caused by the accuracy loss from quantization.
  • For asymmetric quantization on Windows, we need to investigate further.

@yufenglee
Member

And for the issue in the original post, do you run on Windows or Linux?

@VishalX
Author

VishalX commented Feb 21, 2024

@yufenglee I ran on Windows.

@yufenglee
Member

yufenglee commented Feb 22, 2024

@VishalX, just FYI: it turns out something is wrong with mmap on Windows. If I turn off mmap, asymmetric quantization works on Windows. You can try it out with this branch if you want: https://github.com/microsoft/onnxruntime/tree/yufeng/hot_fix. I will investigate further.

@VishalX
Author

VishalX commented Feb 22, 2024

I tried this fix and now get exactly the same responses for the asymmetric model as on Linux.

@VishalX
Author

VishalX commented Feb 22, 2024

@yufenglee,
I tried asymmetric BlockWise, RTN, and GPTQ with the above fix. The responses for all of them include German sentences/words.
Do you think this is due to quantization loss alone?

The numbers published earlier in #17390 report an accuracy of 0.7326 on lambada_openai for GPTQ (G32Asym).

Given responses like these, I'm not sure those numbers can be reproduced. I'll generate the score and see what I get.
Could there be some other issue as well? What do you think?

@yufenglee
Member

@VishalX, it would be great if you could try reproducing the result and get the score.

@VishalX
Author

VishalX commented Feb 23, 2024

@yufenglee I can reproduce the published numbers, with a minor difference:

| Task | Version | Metric | Value | Stderr |
|---|---|---|---|---|
| lambada_openai | 0 | ppl | 3.5593 | ± 0.0714 |
| | | acc | 0.7314 | ± 0.0062 |

Accuracy for lambada_openai is: 0.7314185911119736

I'll check wikitext as well.
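As a quick sanity check on the lambada_openai table above, the sketch below recovers the number of correct examples and the standard error from the reported accuracy. It assumes lambada_openai has 5,153 examples and that the harness reports the sample standard error of a mean of 0/1 outcomes; both are assumptions, not stated in this thread.

```python
import math

N = 5153            # assumed size of the lambada_openai test set
ACC = 0.7314185911119736
PUBLISHED = 0.7326  # GPTQ (G32Asym) figure quoted from #17390

correct = round(ACC * N)
# Sample standard error of the mean of N Bernoulli outcomes: sqrt(p(1-p)/(N-1)).
stderr = math.sqrt(ACC * (1 - ACC) / (N - 1))
gap = PUBLISHED - ACC

if __name__ == "__main__":
    print(correct)                    # 3769
    print(round(stderr, 4))           # 0.0062
    print("gap in stderrs:", round(gap / stderr, 2))
```

Under those assumptions the result works out to 3,769 correct answers and a standard error of about 0.0062, matching the table, and the roughly 0.0012 gap to the published 0.7326 is well under one standard error.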

@VishalX
Author

VishalX commented Feb 23, 2024

For wikitext:

| Task | Version | Metric | Value |
|---|---|---|---|
| wikitext | 1 | word_perplexity | 9.1113 |
| | | byte_perplexity | 1.5116 |
| | | bits_per_byte | 0.5961 |

This also looks close to the published numbers.
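Two of the wikitext numbers are not independent: assuming bits_per_byte is defined as log2(byte_perplexity), which the reported values are consistent with, either one can be derived from the other. A minimal check:

```python
import math

byte_perplexity = 1.5116
# bits_per_byte = log2(byte_perplexity), so byte_perplexity = 2 ** bits_per_byte.
bits_per_byte = math.log2(byte_perplexity)

if __name__ == "__main__":
    print(round(bits_per_byte, 4))  # 0.5961
    print(round(2 ** 0.5961, 4))    # 1.5116
```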

yufenglee added a commit that referenced this issue Feb 25, 2024
### Description
The Windows memory-map code casts mapped_offset directly to DWORD, so the offset is truncated when it is larger than 2^32-1. We need to set dwFileOffsetHigh (the high 32 bits of the offset) in this case.


### Motivation and Context

The bug was found in #19450.
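The truncation described in this commit can be illustrated with plain integer arithmetic. Win32 MapViewOfFile takes the file offset as two 32-bit arguments, dwFileOffsetHigh and dwFileOffsetLow; casting a 64-bit offset straight to a DWORD keeps only the low half, so a view requested 5 GiB into a file silently maps the region at 1 GiB instead. The Python below only emulates the C cast and the high/low split; it is not the ONNX Runtime code, and the function names are hypothetical.

```python
DWORD_MAX = 0xFFFFFFFF  # 2**32 - 1


def truncated_offset(offset: int) -> int:
    """What a direct cast of a 64-bit offset to a 32-bit DWORD keeps (low 32 bits only)."""
    return offset & DWORD_MAX


def split_offset(offset: int) -> tuple:
    """Split a 64-bit file offset into (high, low) 32-bit halves."""
    if not 0 <= offset <= 0xFFFFFFFFFFFFFFFF:
        raise ValueError("offset must fit in 64 bits")
    return (offset >> 32) & DWORD_MAX, offset & DWORD_MAX


if __name__ == "__main__":
    offset = 5 * 2**30  # 5 GiB into the file
    high, low = split_offset(offset)
    print("truncated:", truncated_offset(offset))  # 1073741824 (1 GiB): wrong region
    print("high/low :", high, low)                 # 1 1073741824
    assert (high << 32) | low == offset            # the split round-trips
```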
@VishalX VishalX closed this as completed Feb 26, 2024
@sophies927 sophies927 added the quantization issues related to quantization label Feb 29, 2024
maggie1059 pushed a commit that referenced this issue Mar 8, 2024
maggie1059 added a commit that referenced this issue Mar 11, 2024
Co-authored-by: Yufeng Li
YUNQIUGUO pushed a commit that referenced this issue Mar 21, 2024

5 participants