
[AMD] Count llvm instruction during conversion for scheduling hints #4819

Merged
5 commits merged into triton-lang:main from ravil/sched-barriers-stat on Oct 13, 2024

Conversation

ravil-mobile
Contributor

@ravil-mobile ravil-mobile commented Sep 27, 2024

Advanced software pipelining may require fine-grained adjustments to instruction scheduling in the main `tt.dot` loop to achieve higher performance. Such adjustments require detailed information about the number of issued `v_mfma`, `ds_read`, `ds_write`, and `global_load` instructions. This PR extends the Triton AMDGPU backend by adding instruction counting during `TritonAMDGPUToLLVM` pass execution.

An example of instruction counting and instruction scheduling is demonstrated in the `createCKV3Schedule` method, which implements CK's V3 software pipelining.
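As a rough illustration of what the counting does, the sketch below is a minimal Python model, not the actual MLIR C++ implementation; the mnemonics and the dictionary shape are illustrative assumptions. During lowering, emitted instructions are bucketed into the categories the scheduling hints care about, and the totals are attached to a hint:

```python
# Illustrative model of per-category instruction counting (hypothetical
# names; the real pass counts ops while lowering to LLVM dialect).
from collections import Counter

CATEGORIES = ("v_mfma", "ds_read", "ds_write", "global_load")

def count_instructions(lowered_ops):
    """Bucket lowered instruction mnemonics into the categories the
    scheduling hints need; everything else is ignored."""
    counts = Counter()
    for op in lowered_ops:
        for cat in CATEGORIES:
            if op.startswith(cat):
                counts[cat] += 1
    return {cat: counts[cat] for cat in CATEGORIES}

# A made-up main-loop body after lowering:
loop_body = ["global_load_dwordx4", "ds_write_b128", "ds_read_b128",
             "ds_read_b128", "v_mfma_f32_32x32x8f16", "v_mfma_f32_32x32x8f16"]
hint = count_instructions(loop_body)
print(hint)  # {'v_mfma': 2, 'ds_read': 2, 'ds_write': 1, 'global_load': 1}
```

A schedule builder (such as `createCKV3Schedule`) can then consume these totals to decide how to interleave memory and MFMA instructions.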

Collaborator

@antiagainst antiagainst left a comment


I think overall this looks fine, but there are quite a few places we can simplify. We also need documentation and testing.

lib/Conversion/TritonGPUToLLVM/Utility.cpp (resolved)
lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp (resolved)
lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp (resolved)
@ravil-mobile ravil-mobile force-pushed the ravil/sched-barriers-stat branch 7 times, most recently from aece96b to 06210a4 on October 1, 2024 15:32

op->getBlock()->walk([&](amdgpu::InstructionSchedHint schedHint) {
schedHint.setNumMMAsAttr(counterAttr);
});
Collaborator


I'm wondering if this works when there are multiple tt.dot ops in the loop?

Contributor Author


Hi @zhanglx13,

No, it is not going to work. Supporting multiple `tt.dot` ops would require further investigation and extensions.

Collaborator


Do you plan to generalize the design to support multiple tt.dot ops?
I'm asking because the pipelineV3 or CKV3 pipeline will prefetch the whole LDS buffer, whereas the prefetchLDS pass can prefetch a partial LDS buffer. However, the prefetchLDS pass will lead to multiple tt.dot ops in the loop, each of which corresponds to one prefetched LDS sub-buffer.
The prefetchLDS pass will also need some sched_group_barrier tweaks to "move things around".
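For intuition, the kind of schedule that sched_group_barrier-style tweaks express can be modeled as alternating small groups of memory and MFMA instructions so the two overlap. The sketch below is a hypothetical Python model; the group size and counts are made up for illustration and are not taken from the actual pass:

```python
# Hypothetical model of an interleaved schedule: given per-loop instruction
# counts, emit alternating groups of ds_read and v_mfma instructions.
def interleave(num_ds_reads, num_mfmas, group_size=2):
    """Return a list of (category, count) groups, round-robining between
    the two categories until both counts are exhausted."""
    schedule = []
    reads, mfmas = num_ds_reads, num_mfmas
    while reads > 0 or mfmas > 0:
        take = min(group_size, reads)
        if take:
            schedule.append(("ds_read", take))
            reads -= take
        take = min(group_size, mfmas)
        if take:
            schedule.append(("v_mfma", take))
            mfmas -= take
    return schedule

print(interleave(4, 8))
# [('ds_read', 2), ('v_mfma', 2), ('ds_read', 2), ('v_mfma', 2),
#  ('v_mfma', 2), ('v_mfma', 2)]
```

This is why accurate per-category counts matter: the schedule builder needs to know how many instructions of each kind exist before it can partition them into groups.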

Collaborator


Yeah, I feel we may need more targeted instruction counting. The hint op is basically carrying side-channel information for the tt.dot; we can have one hint op immediately before/after a tt.dot for that tt.dot. It's a bit fragile but fine if we insert it at the proper time. Then we may need to build different schedules for different tt.dot ops (e.g., in the main loop vs. in the epilogue), and the instruction counting needs to be more clever to figure out the different "segments".
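The "segments" idea above can be sketched as follows: instead of one global count per block, attribute counts to the nearest preceding hint marker. This is an illustrative Python model with a hypothetical `"HINT"` marker standing in for the hint op; it is not the actual implementation:

```python
# Illustrative model of segmented instruction counting: each "HINT"
# marker opens a new counting segment (one hypothetical marker per
# tt.dot); instructions are counted against the current segment.
def count_per_segment(stream):
    """stream: list of instruction mnemonics, with "HINT" markers.
    Returns one {mnemonic: count} dict per segment."""
    segments = []
    current = None
    for instr in stream:
        if instr == "HINT":
            current = {}
            segments.append(current)
        elif current is not None:
            current[instr] = current.get(instr, 0) + 1
    return segments

stream = ["HINT", "ds_read", "v_mfma",
          "HINT", "ds_read", "ds_read", "v_mfma"]
print(count_per_segment(stream))
# [{'ds_read': 1, 'v_mfma': 1}, {'ds_read': 2, 'v_mfma': 1}]
```

With per-segment counts, each tt.dot could then get its own schedule (main loop vs. epilogue) built from the counts of its own segment only.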

Collaborator

@antiagainst antiagainst left a comment


Cool! The implementation looks better now. The major missing pieces are still documentation and testing.

let arguments = (ins
I32Attr:$numDsReadsTileA,
I32Attr:$numDsReadsTileB,
I32Attr:$numDsWritesTileA,
Collaborator


I see, thanks! You might want to put the link directly in the comment so it's easy to associate? (Right now what you have there is not a permalink.)

https://github.com/ROCm/composable_kernel/blob/de3e3b642402eac5b4a466f6a2fa5e9f022ba680/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp#L160-L263



@antiagainst antiagainst changed the title [AMD] instruction counting during TritonAMDGPUToLLVM pass [AMD] Count llvm instruction during conversion for scheduling hints Oct 3, 2024
@ravil-mobile ravil-mobile force-pushed the ravil/sched-barriers-stat branch 9 times, most recently from d861c01 to cf97e35 on October 4, 2024 14:18
@ravil-mobile ravil-mobile marked this pull request as ready for review October 4, 2024 16:46
Collaborator

@antiagainst antiagainst left a comment


Implementation looks good now. Just need to add tests next:

  • Op tests for the new hint op
  • Conversion tests for the pass
  • etc.

@antiagainst
Collaborator

BTW, @ravil-mobile, when you address comments, please use separate commits; don't squash everything into one commit--otherwise reviewers are required to reread everything. Separate commits allow us to read only the delta easily. Also, prefer git merge origin/main over force-pushing--it also helps speed up code reviews. Thanks! :)

@ravil-mobile
Contributor Author


Agree, makes sense.

@ravil-mobile ravil-mobile force-pushed the ravil/sched-barriers-stat branch 2 times, most recently from bf3443b to 32c91ac on October 7, 2024 16:15
Collaborator

@antiagainst antiagainst left a comment


Can you also fix the failing tests?

test/TritonGPU/amd/amd-instruction-sched.mlir (resolved)
test/TritonGPU/amd/amd-instruction-sched.mlir (resolved)
test/TritonGPU/amd/amd-instruction-sched.mlir (resolved)
test/TritonGPU/amd/amd-instruction-sched.mlir (resolved)
@ravil-mobile ravil-mobile force-pushed the ravil/sched-barriers-stat branch 6 times, most recently from a06f1a7 to cbbc694 on October 10, 2024 15:09
@antiagainst antiagainst merged commit e87f877 into triton-lang:main Oct 13, 2024
7 checks passed
ptillet added a commit that referenced this pull request Oct 16, 2024
ptillet added a commit that referenced this pull request Oct 16, 2024
antiagainst pushed a commit that referenced this pull request Oct 31, 2024
This commit relands #4819
with the following fixes:

* Changed to a better way to mark opIdx for loads
* Replaced the template-based `rewindUnaryOps` with regular
  for-loops. The new approach is more robust and handles other
  unary ops automatically.
* Replaced `instr.sched.barriers` with the equivalents from the
  upstream MLIR `rocdl` dialect
* Extended lit tests
Luosuu pushed a commit to Luosuu/triton that referenced this pull request Nov 13, 2024
…riton-lang#4819)

Advanced software pipelining may require fine-grained adjustments
to instruction scheduling in the main `tt.dot` loop to achieve
higher performance. Such adjustments require detailed information
about the number of issued `v_mfma`, `ds_read`, `ds_write`, and
`global_load` instructions. This PR extends the Triton AMDGPU backend
by adding instruction counting during `TritonAMDGPUToLLVM` pass
execution.

An example of instruction counting and instruction scheduling is
demonstrated in the `createCKV3Schedule` method which implements the
[CK's V3 software
pipelining](https://github.com/ROCm/composable_kernel/blob/de3e3b642402eac5b4a466f6a2fa5e9f022ba680/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp#L160-L263).

This change is experimental for better GEMM performance. The design
is not final and may be subject to change in the future.
Luosuu pushed a commit to Luosuu/triton that referenced this pull request Nov 13, 2024
guacamoleo pushed a commit to guacamoleo/triton that referenced this pull request Nov 14, 2024
guacamoleo pushed a commit to guacamoleo/triton that referenced this pull request Nov 14, 2024