[AMD] Count llvm instruction during conversion for scheduling hints #4819
I think overall this looks fine, but there are quite a few places we can simplify. We also need documentation and testing.
include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h (outdated review thread, resolved)
op->getBlock()->walk([&](amdgpu::InstructionSchedHint schedHint) {
  schedHint.setNumMMAsAttr(counterAttr);
});
I'm wondering if this works when there are multiple tt.dot ops in the loop?
Hi @zhanglx13,
No, it is not going to work. Supporting multiple tt.dot ops would require further investigation and extensions.
Do you plan to generalize the design to support multiple tt.dot ops?
I'm asking because the pipelineV3 or CKV3 pipeline prefetches the whole LDS buffer, whereas the prefetchLDS pass can prefetch a partial LDS buffer. However, the prefetchLDS pass leads to multiple tt.dot ops in the loop, each of which corresponds to one prefetched LDS sub-buffer.
The prefetchLDS pass will also need some sched_group_barrier tweaks to "move things around".
Yeah, I feel we may need more targeted instruction counting. The hint op is basically carrying side-channel information for the tt.dot; we can have one hint op immediately before/after a tt.dot for that tt.dot. It's a bit fragile, but fine if we insert it at the proper time. Then we may need to build different schedules for different tt.dot ops (e.g., in the main loop vs. in the epilogue), and the instruction counting needs to be more clever to figure out the different "segments".
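To make the "more targeted" idea concrete, here is a minimal, hedged sketch (not the PR's actual code): instead of broadcasting the counter to every hint op in the block, as the quoted `walk` does, the count is attached only to a hint op placed immediately before the tt.dot being lowered. The names `op` and `counterAttr` come from the quoted snippet; everything else is illustrative.

```cpp
// Hedged sketch, not the PR's implementation: attach the count only to the
// hint op that directly precedes the tt.dot being converted. `op` is that
// tt.dot's Operation* and `counterAttr` the computed count, as in the
// snippet quoted above.
if (auto schedHint = llvm::dyn_cast_or_null<amdgpu::InstructionSchedHint>(
        op->getPrevNode()))
  schedHint.setNumMMAsAttr(counterAttr);
```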
Cool! The implementation looks better now. The major missing pieces are still documentation and testing.
third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp (outdated review thread, resolved)
third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td (outdated review thread, resolved)
let arguments = (ins
  I32Attr:$numDsReadsTileA,
  I32Attr:$numDsReadsTileB,
  I32Attr:$numDsWritesTileA,
I see, thanks! You might want to put the link directly in the comment so it's easy to associate? (Right now what you have there is not a permalink.)
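For context on the attribute list quoted above: with MLIR ODS, each `I32Attr:$...` argument generates accessors on the op. The getter names below are assumed from standard ODS naming (consistent with the `setNumMMAsAttr` setter seen earlier) and are not quoted from this PR; they only illustrate how a schedule builder could query the counters.

```cpp
// Hedged illustration only: accessor names follow MLIR ODS conventions for
// I32Attr arguments and are assumptions, not quotes from this PR.
uint32_t dsReadsA  = schedHint.getNumDsReadsTileA();
uint32_t dsReadsB  = schedHint.getNumDsReadsTileB();
uint32_t dsWritesA = schedHint.getNumDsWritesTileA();
// These counts can then decide how many ds_read/ds_write ops are grouped
// between consecutive MFMAs when the schedule is emitted.
```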
Implementation looks good now. Just need to add tests next:
- Op tests for the new hint op
- Conversion tests for the pass
- etc.
include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h (outdated review thread, resolved)
BTW, @ravil-mobile, when you address comments, please use separate commits; don't squash everything into one commit, otherwise reviewers are required to reread everything. Separate commits allow us to easily read only the delta. Also prefer to ...
Agree. Makes sense.
Can you also fix the failing tests?
This commit relands #4819 with the following fixes:
* Changed to a better way to mark opIdx for loads
* Replaced the template-based `rewindUnaryOps` with regular for-loops; the new way is more robust and can handle other unary ops automatically
* Replaced `instr.sched.barriers` with the ones from the `rocdl` dialect from upstream MLIR
* Extended lit tests
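As an aside on the `rewindUnaryOps` fix mentioned above, a loop-based rewind can be sketched as follows. This is only an illustration under the assumption that a "unary op" means a single-operand, single-result op; it is not the code from this commit.

```cpp
// Hedged sketch: walk backwards through a chain of unary (one-operand,
// one-result) ops with a plain loop, instead of template recursion that has
// to enumerate the supported op types.
static mlir::Value rewindUnaryOps(mlir::Value value) {
  while (mlir::Operation *def = value.getDefiningOp()) {
    if (def->getNumOperands() != 1 || def->getNumResults() != 1)
      break;
    value = def->getOperand(0);
  }
  return value;
}
```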
…riton-lang#4819) Advanced software pipelining may require fine-grained adjustments to instruction scheduling in the main `tt.dot` loop to achieve higher performance. Such adjustments require detailed information about the number of issued `v_mfma`, `ds_read`, `ds_write`, and `global_load` instructions. This PR extends the Triton AMDGPU backend by adding instruction counting during `TritonAMDGPUToLLVM` pass execution.

An example of instruction counting and instruction scheduling is demonstrated in the `createCKV3Schedule` method, which implements the [CK's V3 software pipelining](https://github.com/ROCm/composable_kernel/blob/de3e3b642402eac5b4a466f6a2fa5e9f022ba680/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp#L160-L263). This change is experimental, aiming at better GEMM performance. The design is not final and may be subject to change in the future.
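To show in a self-contained way how such per-iteration counts can drive a CK-V3-style interleaved schedule, here is a standalone sketch. The counts and the even-spreading policy are hypothetical placeholders, not the heuristics actually used in `createCKV3Schedule`.

```cpp
// Standalone illustration (not code from this PR): spread ds_read and
// global_load instructions evenly between consecutive v_mfma instructions,
// given instruction counts collected for one loop iteration during lowering.
#include <cstdio>

int main() {
  const int numMMAs = 16;      // hypothetical counts for one loop iteration
  const int numDsReads = 16;   // ds_reads for tiles A and B combined
  const int numGlobalLoads = 4;

  for (int mma = 0; mma < numMMAs; ++mma) {
    // Evenly distribute the memory ops across the MFMA stream.
    int reads = (numDsReads * (mma + 1)) / numMMAs - (numDsReads * mma) / numMMAs;
    int loads = (numGlobalLoads * (mma + 1)) / numMMAs -
                (numGlobalLoads * mma) / numMMAs;
    std::printf("group %2d: 1 x v_mfma, %d x ds_read, %d x global_load\n",
                mma, reads, loads);
  }
  return 0;
}
```

In the actual pass, the resulting per-group quotas would be expressed as scheduling hints (e.g., `sched_group_barrier`-style intrinsics) rather than printed.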