
Binary KQ mask #28

Draft · ikawrakow wants to merge 13 commits into main
Conversation

ikawrakow (Owner)
This PR is another attempt at improving performance for large contexts; see #25.

Basically, when we want to process a very long context, the KQ mask, which is stored as f32 (or f16 when using flash attention), becomes quite significant in size. When running on a GPU, the cost of copying the KQ mask to the GPU (the mask is created on the host CPU) becomes non-negligible. When running on a CPU with limited memory bandwidth (basically all x86/x86_64 systems), the KQ mask may not fit in the cache, or, if it does fit, it takes away a significant amount of cache from other data, which has a measurable impact on the performance of the SOFT_MAX (or the new fused SOFT_CAP_MAX) operation. Hence, it is desirable to reduce the size of the KQ mask.
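To put rough numbers on this (the exact shapes depend on the batch and padding sizes of a given run, so this is only an illustration): assuming a 512-token u-batch attending to a 32k-token KV cache, the mask is

```
f32 mask: 32768 x 512 x 4 bytes = 64 MiB
f16 mask: 32768 x 512 x 2 bytes = 32 MiB
```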

If not using ALiBi (which is almost always the case these days), the KQ mask stores only two values: 0 and -INFINITY. It can therefore be represented as a binary mask, reducing its size by a factor of 32.
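A minimal sketch of what the packing looks like (illustrative only, not the actual implementation in this PR; here a set bit means the position is visible):

```c
// Pack an f32 KQ mask, whose entries are either 0.0f (visible) or
// -INFINITY (masked out), into one bit per entry.
#include <math.h>
#include <stdint.h>
#include <string.h>

static void pack_kq_mask(const float * mask_f32, uint32_t * mask_bits, int n) {
    memset(mask_bits, 0, ((n + 31)/32) * sizeof(uint32_t));
    for (int j = 0; j < n; ++j) {
        if (mask_f32[j] == 0.0f) {                 // visible position
            mask_bits[j >> 5] |= 1u << (j & 31);   // set the corresponding bit
        }
    }
}

// When applying the mask, a set bit keeps the logit and a cleared bit
// replaces it with -INFINITY before the softmax:
static inline float apply_bit_mask(float logit, const uint32_t * mask_bits, int j) {
    return (mask_bits[j >> 5] >> (j & 31)) & 1 ? logit : -INFINITY;
}
```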

This PR adds an option to use a binary KQ mask. It is off by default, as not all platforms are implemented yet, but it can be turned on with -bkq or --binary-kq on the command line. The option has no effect when flash attention is used (the KQ mask remains f16 as before). If it is turned on but not supported by the back-end (e.g., non-AVX512 CPUs), the program will assert and terminate.

I see 3-5% performance gains on CUDA and on a Ryzen-7950X CPU for a context of 32k tokens, and about 2-3% on Metal for a context of 16k. So, nothing earth-shattering, and hence I'm not quite convinced this should be merged.

Here we get a small speedup: Gemma-2-2b with a 32k context is ~4% faster on Zen4. On Zen4 we can use
  _mm512_mask_mul_ps(-infinity, mask, s_after, tanh(x*s_before))
to scale and apply the mask in a single op that has the same latency and throughput as _mm512_mul_ps (see the sketch below). Combined with the reduced memory loads relative to a mask stored as fp32 (or fp16), this gives some performance improvement for very large masks (contexts).

It will be much trickier on the other platforms, which do not have masked instructions.
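A rough sketch of the Zen4/AVX-512 idea (names and loop structure are illustrative, not the PR's actual kernel): a 16-bit chunk of the binary KQ mask is used directly as an AVX-512 write mask, so applying the scale and writing -INFINITY to masked lanes is a single _mm512_mask_mul_ps. For soft_cap_max the loaded value would first be replaced by tanh(x*s_before) and the scale by s_after.

```c
#include <immintrin.h>
#include <math.h>
#include <stdint.h>

// n is assumed to be a multiple of 16 here for brevity
void scale_and_mask(float * x, const uint16_t * mask_bits, int n, float scale) {
    const __m512 v_scale  = _mm512_set1_ps(scale);
    const __m512 v_neginf = _mm512_set1_ps(-INFINITY);
    for (int j = 0; j < n/16; ++j) {
        const __mmask16 m = mask_bits[j];            // 16 binary-mask bits -> lane mask
        const __m512 v = _mm512_loadu_ps(x + 16*j);
        // lanes with a set bit get scale*x, lanes with a cleared bit get -INFINITY
        _mm512_storeu_ps(x + 16*j, _mm512_mask_mul_ps(v_neginf, m, v_scale, v));
    }
}
```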
Relatively painless to implement for soft_max and soft_cap_max. We gain 11.5% for LLaMA-8B and ~14% for Gemma-2-2b at 32k tokens. The KQ mask is prepared on the CPU and copied to the GPU, so my guess is that most of the gain comes from the 32x reduction in the amount of data being copied to the GPU.

TODO: flash attention
For now just soft_cap_max. On Gemma2-9b I'm observing a ~2% speedup for a context of 16k tokens.
I need to redo this with better templates.
It is a pain to implement the binary-mask-to-32-bit-value conversion on NEON and AVX2, so I decided to make the binary mask optional.

There is also a commented-out (and not working) attempt for NEON in this commit.
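To illustrate why this is more awkward on platforms without mask registers (illustrative AVX2 sketch, not the PR's code): the mask bits first have to be expanded into a full-width lane mask (broadcast + AND with per-lane bit values + compare) before they can select between the scaled logits and -INFINITY.

```c
#include <immintrin.h>
#include <math.h>
#include <stdint.h>

// n is assumed to be a multiple of 8 here for brevity
void scale_and_mask_avx2(float * x, const uint8_t * mask_bits, int n, float scale) {
    const __m256  v_scale  = _mm256_set1_ps(scale);
    const __m256  v_neginf = _mm256_set1_ps(-INFINITY);
    const __m256i bit_sel  = _mm256_setr_epi32(1, 2, 4, 8, 16, 32, 64, 128);
    for (int j = 0; j < n/8; ++j) {
        // broadcast the 8 mask bits, isolate one bit per 32-bit lane, turn into 0/-1
        const __m256i bits = _mm256_and_si256(_mm256_set1_epi32(mask_bits[j]), bit_sel);
        const __m256  sel  = _mm256_castsi256_ps(_mm256_cmpeq_epi32(bits, bit_sel));
        const __m256  v    = _mm256_mul_ps(_mm256_loadu_ps(x + 8*j), v_scale);
        // set bit -> scaled logit, cleared bit -> -INFINITY
        _mm256_storeu_ps(x + 8*j, _mm256_blendv_ps(v_neginf, v, sel));
    }
}
```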