Commit
1. Splash attention supports GQA, so don't repeat the KV projection. When Q and KV have different numbers of heads, splash attention handles the broadcast by itself:
   https://github.com/jax-ml/jax/blob/7b9914d711593dca8725d46aa1dadb2194284519/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py#L934
2. Use the splash attention lazy mask instead of a jnp mask. A dense jnp mask costs O(T^2) memory, which almost negates the benefit of reducing HBM traffic with flash attention; the splash attention lazy mask generates the causal mask lazily instead. In addition, Pallas supports CPU simulation (interpret=True), so the same Pallas kernel can run on CPU, which makes the code easier to debug. (See the sketch after the benchmark tables.)

* Benchmark: on TPUv4, (model_dim/heads/kv_heads/seq_len).

AS-IS:

------------------------------------------------------------------------------------------------
Benchmark                                       Time             CPU   Iterations        HBM
------------------------------------------------------------------------------------------------
FlashAttentionBenchmark/256/2/2/512          3.10 ms         1.69 ms          346      2.45M
FlashAttentionBenchmark/2048/16/2/1024       4014 ms          373 ms            1      4.27M
FlashAttentionBenchmark/4096/16/2/1024       3822 ms          335 ms            1      6.17M
FlashAttentionBenchmark/4096/16/2/4096       4230 ms         1533 ms            1    133.34M
FlashAttentionBenchmark/4096/16/2/8192       8233 ms         5481 ms            1    389.42M
FlashAttentionBenchmark/4096/16/2/32768     93024 ms        88780 ms            1      1.50G

This PR saves both memory and computation. In long context, the speed-up is significant (9x).

------------------------------------------------------------------------------------------------
Benchmark                                       Time             CPU   Iterations        HBM
------------------------------------------------------------------------------------------------
FlashAttentionBenchmark/256/2/2/512          3.25 ms         1.75 ms          290      1.43M
FlashAttentionBenchmark/2048/16/2/1024       4130 ms          191 ms            1      3.18M
FlashAttentionBenchmark/4096/16/2/1024       3734 ms          191 ms            1      4.18M
FlashAttentionBenchmark/4096/16/2/4096       3026 ms          297 ms            1     68.46M
FlashAttentionBenchmark/4096/16/2/8192       3410 ms          596 ms            1    168.41M
FlashAttentionBenchmark/4096/16/2/32768     10833 ms         6588 ms            1    580.48M
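For reference, a minimal sketch of the pattern described in points 1 and 2, using the splash attention helpers in jax.experimental.pallas.ops.tpu.splash_attention. The shapes, block sizes, and dtype are illustrative assumptions, not the exact configuration used in this commit:

```python
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel,
    splash_attention_mask,
)

# Illustrative sizes (assumptions, not the benchmark configuration).
batch, q_heads, kv_heads, seq_len, head_dim = 2, 16, 2, 1024, 128

# Lazy causal mask: mask blocks are generated inside the kernel, so no dense
# O(T^2) mask array is materialized in HBM.
causal = splash_attention_mask.CausalMask(shape=(seq_len, seq_len))
mh_mask = splash_attention_mask.MultiHeadMask(masks=(causal,) * q_heads)

kernel = splash_attention_kernel.make_splash_mha(
    mask=mh_mask,
    head_shards=1,
    q_seq_shards=1,
    block_sizes=splash_attention_kernel.BlockSizes.get_default(),
)

# Per-example shapes: Q is (q_heads, T, d); K/V are (kv_heads, T, d).
# The kernel maps groups of Q heads onto each KV head, so there is no need to
# repeat the KV projection for GQA.
q = jnp.zeros((batch, q_heads, seq_len, head_dim), jnp.bfloat16)
k = jnp.zeros((batch, kv_heads, seq_len, head_dim), jnp.bfloat16)
v = jnp.zeros((batch, kv_heads, seq_len, head_dim), jnp.bfloat16)

out = jax.vmap(kernel)(q, k, v)  # vmap over the batch dimension
```

For CPU debugging, the commit relies on Pallas interpret mode (interpret=True) so the same kernel runs off-TPU; how that flag is plumbed through to the splash attention kernel depends on the JAX version and the caller's wrapper, so it is omitted from the sketch above.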