Transform: support sdpa to flash attention kernel conversion #131

Draft
wants to merge 1 commit into main

Conversation

yifeizh2 (Contributor) commented Jun 13, 2024

Tracking issue #147.

TODO:

  • Check correctness
  • Align performance
  • Allow tuning for default config
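
For context, here is a minimal NumPy sketch of what the converted kernel computes: blockwise SDPA with an online softmax in the FlashAttention style. The block size, shapes, and function names are illustrative only and do not reflect this pass's actual tiling, fusion, or brgemm configuration.

```python
import numpy as np

def sdpa_reference(q, k, v):
    # plain scaled dot-product attention: softmax(Q @ K^T / sqrt(d)) @ V
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def flash_attention(q, k, v, block=32):
    # blockwise SDPA with an online softmax: K/V are visited tile by tile,
    # so the full seq_len x seq_len score matrix is never materialized
    seq, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros((seq, d), dtype=np.float32)
    m = np.full((seq, 1), -np.inf)     # running row-wise maxima
    l = np.zeros((seq, 1))             # running softmax denominators
    for j in range(0, seq, block):
        kj, vj = k[j:j + block], v[j:j + block]
        s = (q @ kj.T) * scale                          # first matmul of the kernel
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        p = np.exp(s - m_new)                           # the elementwise exp
        corr = np.exp(m - m_new)                        # rescale previous partial sums
        l = l * corr + p.sum(axis=-1, keepdims=True)
        out = out * corr + p @ vj                       # second matmul of the kernel
        m = m_new
    return out / l

q, k, v = (np.random.rand(384, 64).astype(np.float32) for _ in range(3))
assert np.allclose(flash_attention(q, k, v), sdpa_reference(q, k, v), atol=1e-4)
```

The sketch only shows the algorithmic structure (two matmuls plus an exp per K/V tile); the transform itself works at the linalg/scf level.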

@yifeizh2 yifeizh2 changed the base branch from main to zhicong/deep_tile_matmul June 13, 2024 07:52
@zhczhong zhczhong force-pushed the zhicong/deep_tile_matmul branch 3 times, most recently from 206fead to 65dfab8 Compare June 14, 2024 03:10
@zhczhong zhczhong force-pushed the zhicong/deep_tile_matmul branch 3 times, most recently from 79b277d to d69856f Compare June 27, 2024 01:35
yifeizh2 (Contributor, Author) commented Jul 16, 2024

As for performance evaluation, there are two issues:

  • Deep-tiled matmul does not take effect when the parent op of the matmul is scf::forall, so we need to either support this scenario or invoke the deep-tiled matmul directly in the flash attention kernel.
  • There is a roughly 5x performance gap:

| SEQ LENGTH / DTYPE | graph compiler v1 (ms) | v2, block 32, brgemm invoked (ms) | Ratio |
| --- | --- | --- | --- |
| 384 / fp32 | 8.28744 | 48.996 | 5.912079 |
| 768 / fp32 | 39.1193 | 190.557 | 4.871176 |
| 1536 / fp32 | 177.446 | 752.762 | 4.242203 |
| 2304 / fp32 | 389.382 | 1837.45 | 4.718888 |
| 3072 / fp32 | 682.228 | 3273.3813 | 4.798075 |

Next steps for performance alignment are:

  • Compare the precise brgemm config used in both cases (v1 vs. MLIR)
  • Perform a more detailed performance breakdown

ZhennanQin (Contributor) commented:

> As for performance evaluation, there are two issues […]

Please try brgemm instead of matmul, which can provide better performance results.
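
(For readers outside the project: brgemm here refers to batch-reduce GEMM, a microkernel call that accumulates a whole batch of small A/B tiles into one C tile instead of issuing a separate matmul per block. The NumPy sketch below shows only those semantics under that assumption; it is not the actual microkernel or its blocking configuration.)

```python
import numpy as np

def brgemm(a_blocks, b_blocks, c):
    # batch-reduce GEMM semantics: C += sum_i A_i @ B_i, i.e. one call
    # reduces a whole batch of small tiles into a single C tile
    for a, b in zip(a_blocks, b_blocks):
        c += a @ b
    return c

# e.g. reducing a K dimension of 64 in two blocks of 32 into one 32x32 C tile
a_blocks = [np.random.rand(32, 32).astype(np.float32) for _ in range(2)]
b_blocks = [np.random.rand(32, 32).astype(np.float32) for _ in range(2)]
c = brgemm(a_blocks, b_blocks, np.zeros((32, 32), dtype=np.float32))
```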

yifeizh2 (Contributor, Author) commented:

> Please try brgemm instead of matmul, which can provide better performance results.

I dumped the final LLVM IR and verified that the current numbers were collected with brgemm invoked. Previously, when brgemm was not in effect, the performance was 10x worse. I think I need to do a more detailed analysis to find exactly where the performance gap is.

@zhczhong zhczhong force-pushed the zhicong/deep_tile_matmul branch 12 times, most recently from f959a73 to ed5180d Compare July 27, 2024 02:09
@zhczhong zhczhong force-pushed the zhicong/deep_tile_matmul branch 3 times, most recently from b3bf8dc to 23dfa97 Compare August 1, 2024 01:27
@yifeizh2 yifeizh2 changed the base branch from zhicong/deep_tile_matmul to main August 2, 2024 08:27
yifeizh2 (Contributor, Author) commented Aug 2, 2024

Latest performance:

| SEQ LENGTH / DTYPE | graph compiler v1 (ms) | v2, block 64, brgemm invoked (ms) | Ratio |
| --- | --- | --- | --- |
| 384 / fp32 | 8.28744 | 22.482 | 2.71 |
| 768 / fp32 | 39.1193 | 93.392 | 2.387 |
| 1536 / fp32 | 177.446 | 377.7458 | 2.128 |
| 2304 / fp32 | 389.382 | 810.249 | 2.080 |
| 3072 / fp32 | 682.228 | 1514.491 | 2.220 |

yifeizh2 (Contributor, Author) commented Aug 2, 2024

The currently observed gaps from v1 are the following:

  • No vectorization
  • No fast transpose
  • No post-op fusion
  • linalg.exp takes about 1/3 of total execution time and needs to be lowered to an optimized version (see the sketch below)
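
As a rough illustration of what an optimized elementwise exponential usually looks like (split the exponent into an integer part applied exactly through the float exponent field and a fractional part approximated by a short polynomial), here is a NumPy sketch. The coefficients are simply the truncated Taylor series of 2**f; a real lowering would use minimax coefficients and more terms, so treat this as the shape of the solution, not this pass's implementation.

```python
import numpy as np

def fast_exp(x):
    # Approximate elementwise exp for fp32 via exp(x) = 2**(x * log2(e)):
    # the integer part of the exponent is applied exactly with ldexp,
    # the fractional part is approximated by a short polynomial.
    y = x.astype(np.float32) * np.float32(1.4426950408889634)  # x * log2(e)
    n = np.floor(y)
    f = y - n                                                   # f in [0, 1)
    # truncated Taylor series of 2**f = exp(f * ln 2)
    p = 1.0 + f * (0.6931472 + f * (0.2402265 + f * 0.0555041))
    return np.ldexp(p, n.astype(np.int32)).astype(np.float32)

x = np.linspace(-10.0, 10.0, 1001, dtype=np.float32)
rel_err = np.max(np.abs(fast_exp(x) - np.exp(x)) / np.exp(x))
print(rel_err)  # roughly 1e-3 to 1e-2 with this low-degree polynomial
```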

Successfully merging this pull request may close these issues:

  • [Experimental] Scaled Dot Product Attention FlashAttention Algorithm Conversion