Where is the code about "remaining layers use faster half precision accumulate"? #10

Open
goldhuang opened this issue Sep 3, 2024 · 5 comments

Comments

@goldhuang

> Flux diffusion model implementation using quantized fp8 matmul & remaining layers use faster half precision accumulate, which is ~2x faster on consumer devices.

Hello there!
Thanks for sharing your quantization implementation of Flux!
I have a question about "remaining layers use faster half precision accumulate". Could you point out the lines in the repo that enable the faster half-precision accumulate?
Thanks in advance!

@aredden
Owner

aredden commented Sep 4, 2024

It's the CublasLinear layers. They come from a repo I made that allows matmuls to run with half-precision accumulate inside the matmul kernel, which roughly doubles the TFLOPS on most consumer GPUs. The source is here: https://github.com/aredden/torch-cublas-hgemm. The CublasLinear replacements happen in the float8_quantize.py file; that's where the swap occurs.
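
For a rough picture of what that swap looks like, here is a minimal sketch of replacing nn.Linear modules with a CublasLinear-style drop-in. It is not the repo's actual code: the cublas_ops import path and the CublasLinear constructor signature are assumptions based on the linked repo, so check torch-cublas-hgemm and float8_quantize.py for the real API.

```python
# Minimal sketch (not the repo's actual code) of swapping nn.Linear layers
# for a half-precision-accumulate replacement. The `cublas_ops` import path
# and the CublasLinear constructor are assumptions based on the linked repo.
import torch.nn as nn

try:
    from cublas_ops import CublasLinear  # assumed import path; may differ
except ImportError:
    CublasLinear = None


def swap_linears(module: nn.Module) -> None:
    """Recursively replace nn.Linear children with CublasLinear copies."""
    if CublasLinear is None:
        return  # library not installed; leave the model untouched
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear):
            new_layer = CublasLinear(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
            )
            # Reuse the existing (fp16) weights so the swap changes nothing
            # except the accumulation precision inside the matmul kernel.
            new_layer.weight.data.copy_(child.weight.data)
            if child.bias is not None:
                new_layer.bias.data.copy_(child.bias.data)
            setattr(module, name, new_layer)
        else:
            swap_linears(child)
```

In the repo itself this happens in float8_quantize.py, and only for the layers that are kept in fp16 rather than quantized to fp8.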

@goldhuang
Author

@aredden Thanks for your detailed answer!
I have two follow-up questions:

  1. Why do you only replace the linear layers in the single/double blocks with fp8?
  2. Why does CublasLinear only support float16?

@aredden
Owner

aredden commented Sep 5, 2024

  1. You can optionally quantize the others by setting "quantize_flow_embedder_layers": true, but it reduces quality pretty considerably, and quantizing them doesn't save much extra VRAM or improve it/s. The non-single-or-double-block layers make up only ~2% of the model's actual weights, yet have a considerable effect on quality.

  2. Well, if you check the Ada whitepaper, you'll find that the peak theoretical TFLOPS for fp16 with fp32 accumulate is ~160 on a 4090, but ~330 for fp16 with fp16 accumulate. Unfortunately you can't use fp16 accumulate with anything other than fp16 tensors, and bf16 can't be used as the accumulation datatype, so the only way to reach those TFLOPS on consumer GPUs is via fp16. It's actually the same speed as fp8 matmul!
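
To get a feel for where a given GPU lands relative to those numbers, here is a small sketch (not from this repo) that times a plain fp16 torch.matmul, which goes through PyTorch's default fp32-accumulate path, and reports achieved TFLOPS; results will vary with matrix shape, clocks, and cuBLAS kernel selection.

```python
# Quick sketch: measure achieved TFLOPS of a plain fp16 matmul (PyTorch's
# default path accumulates in fp32) and compare against the Ada whitepaper
# peaks quoted above (~160 TFLOPS fp16/fp32-acc vs ~330 TFLOPS fp16/fp16-acc
# on a 4090). Numbers vary with shape, clocks, and cuBLAS heuristics.
import torch


def bench_matmul_tflops(m: int = 8192, n: int = 8192, k: int = 8192,
                        iters: int = 50) -> float:
    a = torch.randn(m, k, device="cuda", dtype=torch.float16)
    b = torch.randn(k, n, device="cuda", dtype=torch.float16)
    # Warm up so cuBLAS picks its kernel before timing.
    for _ in range(10):
        torch.matmul(a, b)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        torch.matmul(a, b)
    end.record()
    torch.cuda.synchronize()
    seconds = start.elapsed_time(end) / 1000 / iters  # elapsed_time is in ms
    flops = 2 * m * n * k                             # multiply-adds per GEMM
    return flops / seconds / 1e12


if __name__ == "__main__":
    print(f"achieved: {bench_matmul_tflops():.1f} TFLOPS")
    # An fp16-accumulate kernel (e.g. via torch-cublas-hgemm) should land much
    # closer to the ~330 TFLOPS peak than the default fp32-accumulate path.
```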

@spejamas

Hey @aredden, will a datacenter GPU (like an L40S, for example) get any benefit from the cublas swap?

@aredden
Owner

aredden commented Oct 31, 2024

Not really. It has enough SRAM that it gets the same TFLOPS for fp16 w/ fp32 accumulate as it does for fp16 w/ fp16 accumulate. @spejamas
