Support causal flash attention #2425

Open

jopperm wants to merge 3 commits into main

Conversation

jopperm (Contributor) commented on Oct 4, 2024

This PR adds support for causal FA (a brief reference sketch of the causal-masking semantics follows the list):

  • Keeps the encoding on row-vector tensor operations, as these must be left untouched when lowering to the SIMT program.
  • Extends the pattern-matching helper that determines whether a tensor is transposed so that it also looks through advance operations. (The second attention loop uses a transposed tensor pointer that is tt.advance'd between the loops.)
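
As a point of reference (not part of this PR's changes), the semantics of Causal = True can be sketched with a plain PyTorch implementation of single-head causal attention; the shapes, names, and 1/sqrt(d) scaling below are illustrative assumptions following the usual attention convention:

```python
import math
import torch

def causal_attention_reference(q, k, v):
    """Plain (non-flash) causal attention for a single head.

    q, k, v: tensors of shape (seq_len, head_dim).
    Query position i may only attend to key positions j <= i.
    """
    seq_len, head_dim = q.shape
    scores = (q @ k.T) / math.sqrt(head_dim)
    # A lower-triangular mask enforces the causal constraint.
    causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    scores = scores.masked_fill(~causal_mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

# Example usage:
# q = k = v = torch.randn(128, 64)
# out = causal_attention_reference(q, k, v)
```

A flash-attention kernel computes the same result blockwise, applying the causal mask per tile instead of materializing the full score matrix.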

Signed-off-by: Julian Oppermann <[email protected]>
jopperm self-assigned this on Oct 4, 2024
jopperm (Contributor, Author) commented on Oct 4, 2024

For D_HEAD=128, we're getting tt.make_range ops that are smaller than the sub-group size; I don't know how to lower these yet. Update: after checking the codegen, it seems that no special handling is needed.

jopperm changed the title from "(Almost) support causal flash attention" to "Support causal flash attention" on Oct 4, 2024
Development

Successfully merging this pull request may close these issues.

[#6 Attention Performance] extend attention support for Causal = True