
Improve GEMM perf when one matrix is transposed #2347

Merged: 4 commits merged into main on Sep 27, 2024

Conversation

alexbaden (Contributor)

The 2D block load/store does not work when one of the input matrices to a tt.dot is transposed inside the Triton kernel using the stride parameter. In the user example, the block pointer is transposed via its strides, but the order parameter is left unchanged. As a result, materialize-block-pointer cannot detect that a column-major block_io attribute should be added to the load. Even if the attribute were added, rewrite-tensor-pointer would remove the block pointer, because column-major was not supported there.
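For illustration, here is a minimal sketch of the pattern being described, assuming a B matrix stored as (N, K) that the kernel reads as (K, N) by swapping shape and strides. This is not the reporter's kernel; every name (copy_b_transposed, b_ptr, out_ptr, the stride arguments, the block sizes) is made up for the example. The point is that `order=(1, 0)` stays at the row-major default even though the strides describe a column-major access.

```python
import triton
import triton.language as tl

@triton.jit
def copy_b_transposed(b_ptr, out_ptr, K, N,
                      stride_bn, stride_bk,   # strides of B as stored (N, K)
                      stride_ok, stride_on,   # strides of the (K, N) output
                      BLOCK_K: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_k = tl.program_id(0)
    pid_n = tl.program_id(1)
    # B is stored as (N, K); swapping shape and strides reads it as (K, N),
    # i.e. B.T. `order=(1, 0)` still claims the last dimension is the fastest
    # varying, so only the strides reveal the effectively column-major layout.
    b_block_ptr = tl.make_block_ptr(
        base=b_ptr,
        shape=(K, N),
        strides=(stride_bk, stride_bn),
        offsets=(pid_k * BLOCK_K, pid_n * BLOCK_N),
        block_shape=(BLOCK_K, BLOCK_N),
        order=(1, 0),
    )
    tile = tl.load(b_block_ptr, boundary_check=(0, 1))
    # Write the tile back out so the kernel is a complete, launchable example;
    # in a real GEMM this tile would instead feed tl.dot.
    out_block_ptr = tl.make_block_ptr(
        base=out_ptr,
        shape=(K, N),
        strides=(stride_ok, stride_on),
        offsets=(pid_k * BLOCK_K, pid_n * BLOCK_N),
        block_shape=(BLOCK_K, BLOCK_N),
        order=(1, 0),
    )
    tl.store(out_block_ptr, tile, boundary_check=(0, 1))
```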

This PR adds support for detecting column-major based on stride instead of order, and it brings the same logic to rewrite-tensor-pointer so that the column-major load is preserved and eventually lowered to a 2D block load. With this, transposed-matrix performance is more in line with the non-transposed version:

```
Compute A x B
(I): Detected 7680 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
✅ Triton and Torch match
Time for torch: 0.31821921467781067 ms
Time for triton: 0.4404735863208771 ms
Compute A x B.T
(I): Detected 7680 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
✅ Triton and Torch match
Time for torch: 0.33270877599716187 ms
Time for triton: 0.6352895498275757 ms
```

I know we plan to remove rewrite-tensor-pointer eventually, but column-major support did not appear difficult, and the resulting performance is nearly 3x better.

I am planning to PR each commit individually, starting with the debug logging, but wanted to open this umbrella PR to show how the entire pipeline fits together.

Closes #1795

chengjunlu (Contributor)

@alexbaden I have created an issue to track the performance gap between the transposed and non-transposed B cases: #2354

alexbaden added a commit that referenced this pull request Sep 26, 2024
Follows the Triton LLVM debug syntax. Allows you to dump various
parameters when running `triton-opt` with `-debug`.

cc #2347
alexbaden added a commit that referenced this pull request Sep 26, 2024
Per the Triton Slack, `order` is unused on architectures below Hopper.
More importantly, order provides information that stride already has.
In fact, order can be completely different from stride (i.e. wrong)
and we still generate correct code. I think it is better to use the
stride, assuming the logic I added here makes sense.

Note this depends on #2348; I'd like to land the debug logging
separately so we have it even if we decide to modify this approach. It
was very useful in debugging this problem.

cc #2347
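To make the stride-based idea from this commit concrete, here is a small Python-only sketch of the kind of classification being described. The actual logic lives in the MaterializeBlockPointer / RewriteTensorPointer passes (MLIR/C++) and may differ in details; the function name and the handling of ambiguous cases below are my own illustration.

```python
# Illustrative sketch of stride-based layout detection, ignoring `order`.
from typing import Optional, Sequence

def classify_2d_layout(strides: Sequence[int]) -> Optional[str]:
    """Classify a 2D block pointer as 'row_major' or 'column_major' from its
    strides alone, without consulting the user-supplied `order` hint."""
    if len(strides) != 2:
        return None  # only 2D block pointers are considered in this sketch
    row_stride, col_stride = strides
    if col_stride == 1 and row_stride != 1:
        return "row_major"      # contiguous along the last dimension
    if row_stride == 1 and col_stride != 1:
        return "column_major"   # contiguous along the first dimension
    return None                 # ambiguous; leave the load unmarked

# Example: a (K, N) view of a row-major (N, K) matrix has strides (1, N_stride),
# which this check classifies as column-major regardless of `order`.
assert classify_2d_layout((1, 4096)) == "column_major"
assert classify_2d_layout((4096, 1)) == "row_major"
```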
alexbaden force-pushed the alex/matmul_transpose_performance branch from 63cb029 to 835e8a0 on September 26, 2024 01:50
whitneywhtsang (Contributor) left a comment:

My preference is to remove the RewriteTensorPointer pass after #2181 lands.
Would that work for your test case?

alexbaden (Author)

Yes, this change is only to keep RewriteTensorPointer from removing the blocked load.

alexbaden force-pushed the alex/matmul_transpose_performance branch from 835e8a0 to 4705bde on September 26, 2024 15:36
alexbaden (Author)

Merging #2181 resolved the fused attention issue because the transposed load is now lowered to a gather load instead of failing to lower to a 2D blocked load.

alexbaden changed the title from "[umbrella] Improve GEMM perf when one matrix is transposed" to "Improve GEMM perf when one matrix is transposed" on Sep 26, 2024
etiotto (Contributor) commented Sep 26, 2024

I think that removing the RewriteTensorPointer pass, as in PR #2359, is the way to go. We want to avoid rewriting block pointers into regular pointers in order to maximize the number of tt.load/tt.store operations that can be lowered to 2D block read/write instructions, so removing that pass seems the right approach to me.

If we land PR #2359, this PR can presumably be rebased and the LIT test kept.

etiotto self-requested a review on September 26, 2024 16:47
alexbaden (Author)

@chengjunlu @whitneywhtsang I modified the RewriteTensorPointer pass to remove make_tensor_ptr for the FP8 column-major DPAS/dot case. This resolves the FP8 regression until we can vectorize the gathered loads. I also cleaned up the code a bit to remove a redundant section and to make the semantics match MaterializeBlockPointer when checking the pitch. Please take a look.

06-fused-attention performance:

this PR:

```
fused-attention-batch4-head32-d64-fwd-causal=True:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0      32.843361     14.780091
1   2048.0      42.282882     17.905053
2   4096.0      48.726973     20.021225
3   8192.0      53.396336     21.004981
4  16384.0      55.262615     21.451676
fused-attention-batch4-head32-d64-fwd-causal=False:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0      42.644897     18.312834
1   2048.0      45.299192     18.500770
2   4096.0      47.940167     18.471753
3   8192.0      53.223156     17.411309
4  16384.0      55.170334     17.677018
fused-attention-batch4-head32-d64-bwd-causal=True:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0       2.283084      2.283702
1   2048.0       2.706487      2.645186
2   4096.0       2.955933      2.965194
3   8192.0       3.108108      3.094504
4  16384.0       3.163420      3.166231
```

main:

```
fused-attention-batch4-head32-d64-fwd-causal=True:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0      22.285431     14.742289
1   2048.0      28.124196     17.888429
2   4096.0      31.669878     20.006270
3   8192.0      33.335991     21.019115
4  16384.0      33.837792     21.324367
fused-attention-batch4-head32-d64-fwd-causal=False:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0      29.751250     18.358382
1   2048.0      32.622166     18.560691
2   4096.0      34.748520     18.097661
3   8192.0      35.116679     17.155187
4  16384.0      35.352343     17.631020
fused-attention-batch4-head32-d64-bwd-causal=True:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0       2.285378      2.285584
1   2048.0       2.694824      2.687230
2   4096.0       2.968936      2.957643
3   8192.0       3.094857      3.096270
4  16384.0       3.166553      3.164195
```

Matmul performance is unchanged.

whitneywhtsang (Contributor) left a comment:
I'm thinking we could check the block_io attribute of the users of the MakeTensorPtrOp; that way we don't need to maintain a copy of the same condition checks in RewriteTensorPointer. WDYT?

alexbaden (Author)

> I'm thinking we could check the block_io attribute of the users of the MakeTensorPtrOp; that way we don't need to maintain a copy of the same condition checks in RewriteTensorPointer. WDYT?

Interesting - I can try it. But let's make it a separate PR. I want to show some progress on the original transpose issue, and I am not 100% confident that I understand the logic well enough to handle all cases, so it will need some testing/review.

whitneywhtsang (Contributor)

> I'm thinking we could check the block_io attribute of the users of the MakeTensorPtrOp; that way we don't need to maintain a copy of the same condition checks in RewriteTensorPointer. WDYT?

> Interesting - I can try it. But let's make it a separate PR. I want to show some progress on the original transpose issue, and I am not 100% confident that I understand the logic well enough to handle all cases, so it will need some testing/review.

Sure, let's do that in a separate PR.

alexbaden merged commit c428109 into main on Sep 27, 2024
4 checks passed
alexbaden deleted the alex/matmul_transpose_performance branch on September 27, 2024 15:44
Successfully merging this pull request may close these issues: [GEMM-perf] matmul is slower when one input needs to be transposed