Optimize TPU Flash Attention
1. Splash attention supports GQA, so don't repeat the KV projections.

When Q and KV have different numbers of heads, splash attention handles the mismatch by itself (see the sketch below):
https://github.com/jax-ml/jax/blob/7b9914d711593dca8725d46aa1dadb2194284519/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py#L934
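
A minimal sketch of the upstream splash attention API handling the head-count mismatch (the shapes and the jax.vmap over batch are illustrative, not the exact code in this layer):

```python
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel,
    splash_attention_mask,
)

batch, q_heads, kv_heads, seq_len, head_dim = 2, 16, 2, 1024, 128
q = jnp.zeros((batch, q_heads, seq_len, head_dim), jnp.float32)
k = jnp.zeros((batch, kv_heads, seq_len, head_dim), jnp.float32)
v = jnp.zeros((batch, kv_heads, seq_len, head_dim), jnp.float32)

# One lazy causal mask per Q head; no dense [T, T] mask tensor is built.
mask = splash_attention_mask.MultiHeadMask(
    masks=(splash_attention_mask.CausalMask(shape=(seq_len, seq_len)),) * q_heads
)
kernel = splash_attention_kernel.make_splash_mha(
    mask=mask, head_shards=1, q_seq_shards=1
)
# K/V keep their 2 heads; the kernel maps the 16 Q heads onto them internally,
# so there is no need for _repeat_kv_heads before the call (TPU only).
out = jax.vmap(kernel)(q, k, v)  # [batch, q_heads, seq_len, head_dim]
```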

2. Use splash attention's lazy mask instead of a jnp mask, whose memory is O(T^2).

A materialized jnp mask costs O(T^2) memory (at T = 32768, even a bool mask is ~1 GiB), which
almost negates the benefit of reducing HBM traffic with flash attention. Splash attention's
lazy mask instead generates the causal mask block by block inside the kernel, as illustrated below.
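
For a sense of scale, a small sketch contrasting the two (the byte math assumes a bool mask and is illustrative):

```python
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask

seq_len = 32768

# Dense jnp mask: T * T elements must be materialized and passed to the kernel.
# Even as bool that is 32768**2 bytes, i.e. ~1 GiB at 32K context.
dense_mask_bytes = seq_len * seq_len
print(f"dense bool mask: {dense_mask_bytes / 2**30:.2f} GiB")

# Lazy mask: a small Python object describing the mask. The kernel generates
# only the (block_q, block_kv) tiles it visits and can skip fully-masked tiles,
# so nothing O(T^2) is ever resident in HBM.
lazy_mask = splash_attention_mask.CausalMask(shape=(seq_len, seq_len))
```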

In addition, Pallas supports CPU simulation (interpret=True), so the same Pallas kernel is now
used on CPU as well, which makes the code easier to debug; a toy example follows.
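
A toy Pallas example of the interpret switch (the kernel here is a stand-in, not the attention kernel from this change):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Stand-in kernel body; the real code path uses the flash/splash kernels.
    o_ref[...] = x_ref[...] + y_ref[...]


def add(x, y):
    # interpret=True runs the kernel in the Pallas interpreter, so the same
    # kernel code can be exercised and debugged on CPU without a TPU.
    interpret = jax.devices()[0].platform == "cpu"
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=interpret,
    )(x, y)


print(add(jnp.ones((8, 128)), jnp.ones((8, 128))))
```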

* Benchmark: on TPUv4; benchmark parameters are (model_dim/num_heads/num_kv_heads/seq_len).

AS-IS
------------------------------------------------------------------------------------------------
Benchmark                                       Time             CPU   Iterations           HBM
------------------------------------------------------------------------------------------------
FlashAttentionBenchmark/256/2/2/512           3.10 ms         1.69 ms          346        2.45M
FlashAttentionBenchmark/2048/16/2/1024        4014 ms          373 ms            1        4.27M
FlashAttentionBenchmark/4096/16/2/1024        3822 ms          335 ms            1        6.17M
FlashAttentionBenchmark/4096/16/2/4096        4230 ms         1533 ms            1      133.34M
FlashAttentionBenchmark/4096/16/2/8192        8233 ms         5481 ms            1      389.42M
FlashAttentionBenchmark/4096/16/2/32768      93024 ms        88780 ms            1        1.50G

TO-BE (this PR): saves both memory and computation. At long context, the speed-up is significant (~9x at seq_len 32768).
------------------------------------------------------------------------------------------------
Benchmark                                       Time             CPU   Iterations           HBM
------------------------------------------------------------------------------------------------
FlashAttentionBenchmark/256/2/2/512           3.25 ms         1.75 ms          290        1.43M
FlashAttentionBenchmark/2048/16/2/1024        4130 ms          191 ms            1        3.18M
FlashAttentionBenchmark/4096/16/2/1024        3734 ms          191 ms            1        4.18M
FlashAttentionBenchmark/4096/16/2/4096        3026 ms          297 ms            1       68.46M
FlashAttentionBenchmark/4096/16/2/8192        3410 ms          596 ms            1      168.41M
FlashAttentionBenchmark/4096/16/2/32768      10833 ms         6588 ms            1      580.48M
ds-hwang committed Nov 18, 2024
1 parent 594313d commit 11b8a07
Showing 6 changed files with 226 additions and 78 deletions.
24 changes: 19 additions & 5 deletions axlearn/common/flash_attention/layer.py
@@ -20,6 +20,7 @@
     make_segment_mask,
 )
 from axlearn.common.config import config_class
+from axlearn.common.flash_attention import tpu_attention
 from axlearn.common.flash_attention.utils import (
     MultiHeadAttentionImpl,
     flash_attention_implementation,
@@ -169,10 +170,6 @@ def _compute_attention(
         cfg = self.config
         backend = self._backend()
 
-        # Repeats key/value heads dim if necessary.
-        k_proj = self._repeat_kv_heads(k_proj)
-        v_proj = self._repeat_kv_heads(v_proj)
-
         batch, target_len, num_heads, _ = q_proj.shape
         _, source_len, _, _ = k_proj.shape
 
@@ -228,7 +225,18 @@ def _compute_attention(
                     f"{k_proj.shape[1]} for correctly supported GPU flash attention usage."
                 )
 
-        if backend == "tpu":
+        if backend == "cpu" and not tpu_attention.check_tpu_splash_attention(
+            query=q_proj,
+            key=k_proj,
+            has_mask=bool(cfg.mask),
+            segment_ids=segment_ids,
+            has_bias=(attention_logit_biases is not None),
+        ):
+            backend = "xla"
+
+        if backend in ("tpu", "cpu"):
+            # Splash attention needs to know sliding_window_size.
+            mask_fn = cfg.mask
             assert q_proj.shape[1] % cfg.tpu_block_size == 0, (
                 f"Target seq len {q_proj.shape[1]} must be "
                 f"divisible by block size {cfg.tpu_block_size}."
@@ -259,6 +267,12 @@ def _compute_attention(
                 attention_logit_biases, attention_logit_biases_spec
             )
 
+        # Note: splash attention supports GQA natively.
+        if backend not in ("tpu", "cpu"):
+            # Repeats key/value heads dim if necessary.
+            k_proj = self._repeat_kv_heads(k_proj)
+            v_proj = self._repeat_kv_heads(v_proj)
+
         # Scale query and key.
         q_proj = self.scale_query(q_proj)
         k_proj = self.scale_key(k_proj)
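
For context, a rough sketch of the kind of gate tpu_attention.check_tpu_splash_attention provides, inferred only from the call site above; the real implementation lives in axlearn/common/flash_attention/tpu_attention.py (not shown in this diff) and its exact conditions may differ:

```python
def check_tpu_splash_attention(
    *, query, key, has_mask: bool, segment_ids=None, has_bias: bool = False
) -> bool:
    """Sketch: True if the Pallas splash attention path can handle these inputs."""
    del has_mask, segment_ids  # A fuller check would also inspect these.
    # Hypothetical condition: splash attention consumes masks lazily and has no
    # dense bias input, so an explicit bias should fall back to the XLA path.
    if has_bias:
        return False
    # Hypothetical condition: the kernel tiles head_dim over 128-wide lanes.
    return query.shape[-1] % 128 == 0 and key.shape[-1] % 128 == 0
```
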
10 changes: 9 additions & 1 deletion axlearn/common/flash_attention/layer_test.py
@@ -16,7 +16,7 @@
 import jax
 import jax.numpy as jnp
 import pytest
-from absl.testing import parameterized
+from absl.testing import absltest, parameterized
 from jax.experimental import mesh_utils
 from jax.sharding import Mesh
 
@@ -91,6 +91,7 @@ def _prepare_layers(
     sliding_window_size,
     inference=False,
     set_layer_bias_recursively=False,
+    tpu_block_size=512,
 ):
     hidden_dim = num_heads * per_head_dim
     kwargs = dict(
@@ -110,6 +111,7 @@
         .set(
             mha_dim_to_partition_spec=default_mha_dim_to_partition_spec(mesh_axis_names),
             output_dim_to_partition_spec=default_output_dim_to_partition_spec(mesh_axis_names),
+            tpu_block_size=tpu_block_size,
         )
     )
     if inference:
@@ -378,7 +380,9 @@ def test_forward(
             mesh_axis_names=mesh_axis_names,
             causal=causal,
             sliding_window_size=sliding_window_size,
+            tpu_block_size=128,
         )
+
         # pylint: disable-next=protected-access
         if test_layer._backend() == "gpu" and query_len_multiplier != 1:
             pytest.skip(
@@ -734,3 +738,7 @@ def test_extend_step(
         atol=2e-2,
     )
     jax.clear_backends()
+
+
+if __name__ == "__main__":
+    absltest.main()
