[intel] Remove RewriteTensorPointer pass #2359

Draft

whitneywhtsang wants to merge 3 commits into main

Conversation

@whitneywhtsang (Contributor)

After #2181, tt.load can be lowered with an arbitrary combination of block pointer and layout, so we can simply remove the RewriteTensorPointer pass.
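
For context, the block-pointer form that the pass used to rewrite looks like the hypothetical kernel below, a minimal sketch (not code from this PR; kernel and parameter names are made up):

```python
import triton
import triton.language as tl

# Illustrative sketch only: a minimal 2D copy kernel written with block
# pointers. After #2181, loads and stores like these can keep their
# block-pointer form through the pipeline regardless of the layout
# assigned to them, instead of being unpacked by RewriteTensorPointer.
@triton.jit
def block_ptr_copy(in_ptr, out_ptr, M, N, stride_m, stride_n,
                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid = tl.program_id(0)
    src = tl.make_block_ptr(base=in_ptr, shape=(M, N),
                            strides=(stride_m, stride_n),
                            offsets=(pid * BLOCK_M, 0),
                            block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    dst = tl.make_block_ptr(base=out_ptr, shape=(M, N),
                            strides=(stride_m, stride_n),
                            offsets=(pid * BLOCK_M, 0),
                            block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    tile = tl.load(src, boundary_check=(0, 1))   # tt.load on a block pointer
    tl.store(dst, tile, boundary_check=(0, 1))   # tt.store on a block pointer
```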

@alexbaden (Contributor)

In principle this looks OK, but it is a pretty big divergence from upstream. I understand we want to propagate the tensor pointer as long as possible so that we can lower it to 2D block loads where possible. But if a 2D block load is not possible, do we lose the opportunity to optimize the unpacked load in the TTGIR?
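
For reference, the unpacked form corresponds roughly to the hand-written tensor-of-pointer kernel below (an illustrative sketch at the Python level, not the pass's actual IR output; names are made up):

```python
import triton
import triton.language as tl

# Sketch of the "unpacked" tensor-of-pointer form. RewriteTensorPointer
# performs the equivalent rewrite on the IR: the block pointer becomes a
# tensor of per-element pointers plus an explicit bounds mask, which is
# the form the existing TTGIR optimizations operate on.
@triton.jit
def unpacked_copy(in_ptr, out_ptr, M, N, stride_m, stride_n,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid = tl.program_id(0)
    offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offsets = offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tile = tl.load(in_ptr + offsets, mask=mask, other=0.0)
    tl.store(out_ptr + offsets, tile, mask=mask)
```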

The other disadvantage is for debugging: now the Triton Intel GPU to LLVM pass does even more work, and it's very hard to debug individual pieces of that pass, compared to being able to represent TritonGEN::Matrix2DBlockLoad in the TTGIR.

Signed-off-by: Whitney Tsang <[email protected]>
@whitneywhtsang (Contributor, Author)

> In principle this looks OK, but it is a pretty big divergence from upstream. I understand we want to propagate the tensor pointer as long as possible so that we can lower it to 2D block loads where possible. But if a 2D block load is not possible, do we lose the opportunity to optimize the unpacked load in the TTGIR?

I understand your concerns. IMO, we should solve this problem in general, even for blocked pointers that can be lowered to 2D block loads. One idea is to introduce a new interface upstream and modify the optimization passes to operate on that interface, so they work for both tensors of pointers and blocked pointers.
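
As a purely conceptual sketch of that idea (the real thing would be an MLIR op interface defined upstream in TableGen/C++; Python is used here only to show the shape of the abstraction, and every name is hypothetical):

```python
from abc import ABC, abstractmethod

# Hypothetical sketch: a uniform view over a load/store, whether it
# addresses memory through a tensor of pointers or a block pointer.
# Optimization passes written against this interface would work on
# both representations unchanged.
class TensorAccessInterface(ABC):
    @abstractmethod
    def base_pointer(self):
        """The underlying scalar base pointer."""

    @abstractmethod
    def per_dim_offsets(self):
        """Offsets of the accessed tile in each dimension."""

    @abstractmethod
    def boundary_info(self):
        """Mask or boundary-check information, if any."""
```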

> The other disadvantage is for debugging: now the Triton Intel GPU to LLVM pass does even more work, and it's very hard to debug individual pieces of that pass, compared to being able to represent TritonGEN::Matrix2DBlockLoad in the TTGIR.

I am not sure I completely understand your idea, as TritonGEN::Matrix2DBlockLoad is only introduced in the TritonGPU to LLVM conversion, whether or not the RewriteTensorPointer pass is removed. However, in general I agree that the ConvertTritonGPUToLLVM pass is doing too much, and splitting it into more passes should help debugging.

Another motivation for removing the RewriteTensorPointer pass is to improve the generated code: the code generated by RewriteTensorPointer was identified as suboptimal in #1766.

@whitneywhtsang (Contributor, Author)

Baseline:

fused-attention-batch4-head32-d64-fwd-causal=True:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0      22.549495     14.749113
1   2048.0      28.547367     17.921739
2   4096.0      32.146597     19.973736
3   8192.0      32.947761     20.989199
4  16384.0      33.405508     21.315338
fused-attention-batch4-head32-d64-fwd-causal=False:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0      30.001837     18.370397
1   2048.0      32.612706     18.752953
2   4096.0      34.366668     18.195319
3   8192.0      35.105017     17.859327
4  16384.0      34.337847     18.488528

With RewriteTensorPointer removed:

fused-attention-batch4-head32-d64-fwd-causal=True:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0      30.936479      7.317689
1   2048.0      40.855545      9.091529
2   4096.0      47.497647     10.083325
3   8192.0      51.679951     10.775418
4  16384.0      52.747427     11.146273
fused-attention-batch4-head32-d64-fwd-causal=False:
     N_CTX  Triton [FP16]  Triton [FP8]
0   1024.0      42.115509     10.735080
1   2048.0      44.385075     11.132770
2   4096.0      46.446279     11.850412
3   8192.0      51.499184     12.571295
4  16384.0      53.933902     12.602635

=> In 06-fused-attention.py, FP16 performance improves but FP8 performance degrades.

We need to improve the tt.load and tt.store lowering that converts a blocked pointer to a tensor of pointers before removing RewriteTensorPointer.

@etiotto (Contributor) commented Sep 27, 2024

> We need to improve the tt.load and tt.store lowering that converts a blocked pointer to a tensor of pointers before removing RewriteTensorPointer.

Yup. Is this something you plan to work on in this PR? If not, we can put this PR in draft mode (until that piece of work is done).

whitneywhtsang marked this pull request as draft September 27, 2024 15:02
@whitneywhtsang (Contributor, Author)

> We need to improve the tt.load and tt.store lowering that converts a blocked pointer to a tensor of pointers before removing RewriteTensorPointer.

> Yup. Is this something you plan to work on in this PR? If not, we can put this PR in draft mode (until that piece of work is done).

Moved to draft.
