-
Hi @BBuf, a flash_attention backend is great and we would love to have it.
cc @yzh119 for the bad case.
-
Hi @BBuf, it's not a fair comparison, because you use prefill attention on a page table (BatchPrefillWithPagedKVCache) for flashinfer while using ragged tensors (flash_attn_varlen_func) for flash_attention. BatchPrefillWithRaggedKVCache in flashinfer has the same semantics as, and a similar implementation to, flash_attn_varlen_func. I remember sglang uses both BatchPrefillWithRaggedKVCache and BatchPrefillWithPagedKVCache, but I'm not sure how sglang dispatches between the two implementations. @Ying1123 can you further clarify?
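For reference, here is a rough sketch of the two code paths being compared. The names follow the flash-attn 2.x and flashinfer Python APIs, but the exact flashinfer method names (begin_forward/forward vs. plan/run) and argument lists vary across releases, and the shapes are made up:

```python
import torch
import flashinfer
from flash_attn import flash_attn_varlen_func

num_qo_heads, num_kv_heads, head_dim = 32, 8, 128  # Llama3-8B-like GQA shapes
# Two sequences (512 and 256 tokens) packed along the token axis ("ragged" layout).
qo_indptr = torch.tensor([0, 512, 768], dtype=torch.int32, device="cuda")
kv_indptr = qo_indptr.clone()
total_tokens = 768

q = torch.randn(total_tokens, num_qo_heads, head_dim, dtype=torch.half, device="cuda")
k = torch.randn(total_tokens, num_kv_heads, head_dim, dtype=torch.half, device="cuda")
v = torch.randn_like(k)

# flash-attn path: varlen (ragged) packed Q/K/V, no page table.
o_fa = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=qo_indptr, cu_seqlens_k=kv_indptr,
    max_seqlen_q=512, max_seqlen_k=512,
    causal=True,
)

# flashinfer path with the same semantics: ragged KV, no page table.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
ragged = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace, "NHD")
ragged.begin_forward(qo_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim)
o_fi = ragged.forward(q, k, v, causal=True)

# BatchPrefillWithPagedKVCacheWrapper, by contrast, reads K/V through a page
# table (block indices / last-page lengths), which is the extra indirection the
# original comparison put only on the flashinfer side.
```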
-
The test above was conducted in a custom-built PyTorch Docker container. After switching to an NGC Docker today, I no longer see a significant difference in computation time between flashinfer and flash_attention in this scenario; flashinfer was only about 10% slower. Below are the attention computation times from nsys for the three cases:
- vllm flash_attention: 81 + 52 = 133 us total
- sglang original: 43 + 104 = 147 us total
- sglang always ragged tensor: 71 + 59 = 130 us total

In summary, when deploying Llama3-8B on a single 4090 GPU at qps=8 on my dataset, the large gap I observed between flash attention and flashinfer was due to running in a manually compiled PyTorch Docker container. After switching to an NGC Docker, the gap became much smaller: flashinfer's attention computation was only about 10% slower than flash attention's. Furthermore, after making flashinfer always use ragged tensors, their attention computation times were nearly the same. Therefore, I personally feel there is no need to add a new flash attention backend. We can close this discussion and let users choose whether to always use ragged tensors. Based on the test results, the impact on throughput, TTFT, and TPOT is also small: TTFT decreased from 0.081s to 0.078s, while TPOT and throughput showed no significant change.
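If we do expose such a choice, the dispatch itself is tiny. Below is a purely illustrative sketch: none of the flag, class, or function names are real sglang options, and the wrapper `.forward(...)` calls only loosely follow the flashinfer wrapper API. It just shows the idea of letting users force the ragged path when there is no cached prefix to read through the page table.

```python
from dataclasses import dataclass
from typing import Any

# Hypothetical dispatch between the two flashinfer prefill wrappers.
@dataclass
class PrefillMetadata:
    ragged_wrapper: Any   # e.g. flashinfer.BatchPrefillWithRaggedKVCacheWrapper
    paged_wrapper: Any    # e.g. flashinfer.BatchPrefillWithPagedKVCacheWrapper
    paged_kv_cache: Any   # the paged KV pool tensor
    prefix_lens_sum: int  # total number of cached prefix tokens in the batch

def run_prefill_attention(q, k, v, meta: PrefillMetadata, mode: str = "auto"):
    """mode: "ragged", "paged", or "auto" (ragged when no prefix is cached)."""
    use_ragged = mode == "ragged" or (mode == "auto" and meta.prefix_lens_sum == 0)
    if use_ragged:
        # Same semantics as flash_attn_varlen_func: packed Q/K/V, no page table.
        return meta.ragged_wrapper.forward(q, k, v, causal=True)
    # Prefix tokens live in the paged KV pool, so read K/V through the page table.
    return meta.paged_wrapper.forward(q, meta.paged_kv_cache, causal=True)
```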
-
I tested the LLaMA3-8B model on a single 4090 GPU and found that flashinfer's attention computation is much slower than the flash attention library. I would like to ask whether sglang can support a flash attention backend.
sglang flashinfer: (profiler screenshot)
vllm flash_attention: (profiler screenshot)
Flashinfer is about two times slower than flash_attention here. Fortunately, we did not observe flashinfer being slower than flash attention on other models such as qwen2-72b, and sglang's overall throughput is much higher than vllm's.
I would like to know whether it is possible to support a flash_attention backend. I believe that on certain GPU architectures such as the 4090, or for certain shapes, FlashInfer has bad performance cases.
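To check whether a gap like this comes from the attention kernels themselves rather than from the surrounding framework or container/toolchain, one option is to time the two kernels in isolation with CUDA events. A minimal timing harness is sketched below; the iteration counts are arbitrary, and the commented usage assumes q/k/v tensors, cu_seqlens, and a flashinfer ragged wrapper have already been set up (those names are placeholders):

```python
import torch

def time_cuda(fn, iters: int = 100, warmup: int = 20) -> float:
    """Return the mean per-call latency of fn() in microseconds."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1000.0 / iters  # elapsed_time is in ms

# Usage (q, k, v, cu_seqlens, and ragged_wrapper are assumed to exist already):
#   fa_us = time_cuda(lambda: flash_attn_varlen_func(
#       q, k, v, cu_seqlens, cu_seqlens, 512, 512, causal=True))
#   fi_us = time_cuda(lambda: ragged_wrapper.forward(q, k, v, causal=True))
#   print(f"flash-attn: {fa_us:.1f} us, flashinfer: {fi_us:.1f} us")
```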
I also gave it a try; the code is here:
But after starting the service, the program crashes with an illegal CUDA memory access after a few forward_decode calls. The logs are below. I'm not quite sure what happened, so I need help.
The serving command is: