
[Performance] Improve the flash attention performance on bottom-up optimization pipeline #2177

Open · chengjunlu opened this issue Sep 10, 2024 · 4 comments · Fixed by #2181 · May be fixed by #2359

chengjunlu (Contributor) commented Sep 10, 2024

This issue tracks the new design required for flash attention on the bottom-up optimization pipeline.

Status

Most of the optimization passes have been finished and checked into the llvm-target branch, and all of the tasks in the old issue #878 have been completed. The GEMM Triton kernel written with block pointer syntax reaches 90% of the performance of the XeTLA version. Flash attention with block pointers also shows promising performance after some simple changes to the RewriteBlockPointer pass.
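For reference, here is a minimal sketch of the block pointer syntax in question (illustrative only: the kernel name, shapes, and strides are placeholders, not code from this repository):

```python
import triton
import triton.language as tl

@triton.jit
def copy_tile_block_ptr(a_ptr, out_ptr, M, K, stride_am, stride_ak,
                        BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr):
    # Block pointer describing one BLOCK_M x BLOCK_K tile of a row-major M x K matrix.
    a_bptr = tl.make_block_ptr(base=a_ptr, shape=(M, K),
                               strides=(stride_am, stride_ak), offsets=(0, 0),
                               block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    # tt.load with a block pointer operand; this is the form the backend can map
    # to Intel 2D block IO when the access pattern allows it.
    a = tl.load(a_bptr, boundary_check=(0, 1))
    out_bptr = tl.make_block_ptr(base=out_ptr, shape=(M, K),
                                 strides=(stride_am, stride_ak), offsets=(0, 0),
                                 block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    tl.store(out_bptr, a, boundary_check=(0, 1))
```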

New problem

Two new problems were found while developing the bottom-up optimization pipeline:

  1. FP8 flash attention is already supported and must continue to be supported. We need a new implementation when lowering tt.load to support FP8 for flash attention (see the sketch right after this list).
  2. The RewriteBlockPointer pass generates inefficient code; this is tracked in "[Performance] Improve the code generated by the RewriteTensorPointer pass" #1766.
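Regarding item 1, a hedged sketch of the kind of FP8 load the tt.load lowering has to handle (the tl.float8e5 dtype, the block shapes, and the widening to fp16 are assumptions for illustration; this is not the actual flash attention kernel):

```python
import triton
import triton.language as tl

@triton.jit
def qk_fp8_tile(Q, K, M, N, D, stride_qm, stride_qd, stride_kd, stride_kn,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr):
    # Block pointers into FP8 (e.g. tl.float8e5) Q and K tensors.
    q_bptr = tl.make_block_ptr(base=Q, shape=(M, D), strides=(stride_qm, stride_qd),
                               offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_D), order=(1, 0))
    k_bptr = tl.make_block_ptr(base=K, shape=(D, N), strides=(stride_kd, stride_kn),
                               offsets=(0, 0), block_shape=(BLOCK_D, BLOCK_N), order=(0, 1))
    q = tl.load(q_bptr, boundary_check=(0, 1))  # tt.load of an fp8-typed block
    k = tl.load(k_bptr, boundary_check=(0, 1))
    # Widening to fp16 before the dot is one option; whether the lowering keeps the
    # FP8 data on the block-IO path is exactly the open implementation question.
    return tl.dot(q.to(tl.float16), k.to(tl.float16))
```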

Plan

To achieve both the performance and functionality goals of the bottom-up phase, we need a different implementation than originally planned.

  1. Support falling back to gather/scatter memory access when lowering a tt.load operation that uses a block pointer as its memory pointer (see the gather/scatter sketch below). Optionally, support falling back to Intel 1D block IO.
  2. Remove the RewriteBlockPointer pass entirely once the memory access operations can load a block pointer into any layout. (1st step.)

This design can also benefit future features such as the TMA descriptor.
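As a rough illustration of the fallback in step 1, this is the gather-style (tensor-of-pointers) access the lowering could fall back to when the 2D block IO constraints are not met. A minimal sketch under assumed names; it is not taken from this repository:

```python
import triton
import triton.language as tl

@triton.jit
def copy_tile_gather(a_ptr, out_ptr, M, K, stride_am, stride_ak,
                     BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, BLOCK_K)
    # Tensor-of-pointers form: every element of the tile gets its own address and
    # mask bit, which is what gather/scatter memory access boils down to.
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
    a = tl.load(a_ptrs, mask=mask, other=0.0)
    out_ptrs = out_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    tl.store(out_ptrs, a, mask=mask)
```

The point of the plan is that the kernel keeps the block pointer form, and the backend chooses between 2D block IO and this gather/scatter form while lowering tt.load.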

etiotto (Contributor) commented Sep 13, 2024

@chengjunlu what about #950? That is probably needed to reduce register pressure. Can we track it here?

chengjunlu (Contributor, Author) replied:
> @chengjunlu what about #950? That is probably needed to reduce register pressure. Can we track it here?

Forgot that one. Yes, let's track it here too.

whitneywhtsang (Contributor) commented Oct 2, 2024

On agama 996 with the latest main (commit 6f89dbe), Triton performance is 40% of XeTLA.
Note: TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT=1 is removed.
Measurements were done on a PVC 1100.

| Z | H | N_CTX | D_HEAD | CAUSAL | Triton TFLOPS | XeTLA TFLOPS | Triton/XeTLA |
|---|---|-------|--------|--------|---------------|--------------|--------------|
| 1 | 16 | 16384 | 128 | FALSE | 28.63831 | 80.13746 | 36% |
| 1 | 16 | 16384 | 128 | TRUE | 61.08398 | 156.2801 | 39% |
| 1 | 32 | 16384 | 64 | FALSE | 45.27705 | 103.8121 | 44% |
| 1 | 32 | 16384 | 64 | TRUE | 96.77028 | 195.821 | 49% |
| 2 | 16 | 8192 | 128 | FALSE | 28.07994 | 79.68677 | 35% |
| 2 | 16 | 8192 | 128 | TRUE | 57.68881 | 154.5856 | 37% |
| 2 | 32 | 8192 | 64 | FALSE | 42.40634 | 102.8581 | 41% |
| 2 | 32 | 8192 | 64 | TRUE | 88.55945 | 191.986 | 46% |
| 4 | 16 | 4096 | 128 | FALSE | 27.27483 | 78.8773 | 35% |
| 4 | 16 | 4096 | 128 | TRUE | 53.9501 | 151.3645 | 36% |
| 4 | 32 | 4096 | 64 | FALSE | 41.54896 | 97.96356 | 42% |
| 4 | 32 | 4096 | 64 | TRUE | 82.51222 | 184.9685 | 45% |
| 4 | 48 | 1024 | 64 | FALSE | 42.10615 | 80.2098 | 52% |
| 4 | 48 | 1024 | 64 | TRUE | 64.33444 | 127.5733 | 50% |
| 8 | 16 | 2048 | 128 | FALSE | 26.73036 | 78.13827 | 34% |
| 8 | 16 | 2048 | 128 | TRUE | 48.5033 | 145.088 | 33% |
| 8 | 32 | 2048 | 64 | FALSE | 40.18589 | 96.59209 | 42% |
| 8 | 32 | 2048 | 64 | TRUE | 73.87603 | 170.4972 | 43% |
| 16 | 16 | 1024 | 128 | FALSE | 26.08306 | 74.89045 | 35% |
| 16 | 16 | 1024 | 128 | TRUE | 40.21317 | 131.044 | 31% |
| 16 | 32 | 1024 | 64 | FALSE | 39.02387 | 90.7835 | 43% |
| 16 | 32 | 1024 | 64 | TRUE | 59.31046 | 149.4855 | 40% |
| 32 | 16 | 512 | 128 | FALSE | 24.72351 | 70.58867 | 35% |
| 32 | 16 | 512 | 128 | TRUE | 32.41974 | 104.963 | 31% |
| 32 | 32 | 512 | 64 | FALSE | 36.66681 | 80.65018 | 45% |
| 32 | 32 | 512 | 64 | TRUE | 45.07732 | 111.693 | 40% |
| | | | | | | GEOMEAN | 40% |
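For reference, a small sketch (assuming the GEOMEAN row is the geometric mean of the per-row Triton/XeTLA ratios) of how the 40% summary figure can be reproduced in plain Python:

```python
import math

def geomean(ratios):
    # Geometric mean of the per-configuration Triton/XeTLA ratios.
    return math.exp(sum(math.log(r) for r in ratios) / len(ratios))

# First three rows of the table above (Triton TFLOPS / XeTLA TFLOPS); extend with the
# remaining rows to reproduce the 40% GEOMEAN.
ratios = [28.63831 / 80.13746, 61.08398 / 156.2801, 45.27705 / 103.8121]
print(f"geomean = {geomean(ratios):.0%}")
```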

whitneywhtsang (Contributor) commented:
Triton/XeTLA improves from 40% to 43% by removing all environment variables on agama 996.
