Hoist reduction outside a loop #4559
Conversation
This PR introduces an optimization that hoists the reduction of the dot accumulator outside the loop over the K dimension:

```
%acc = <zero tensor>
for k tiles:
  %acc3d_input = reshape %acc
  %acc3d_out = dot3d(%x, %y, %acc3d_input)
  %acc = reduction batch %acc3d_out
```

transforms to:

```
%acc3d = <zero tensor>
for k tiles:
  %acc3d = dot3d(%x, %y, %acc3d)
%acc = reduction batch %acc3d
```
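At the Triton source level, the transformed pattern roughly corresponds to a split-K style kernel in which a batched (3D) `tl.dot` accumulates partial sums and a single `tl.sum` over the batch axis runs after the K loop. The sketch below is only an illustration: the kernel name, tile sizes, and pointer arithmetic are assumptions, and it presumes a Triton build that accepts 3D and small-M dots (one of the goals of this PR series); it is not the actual output of this pass.

```python
import triton
import triton.language as tl


@triton.jit
def split_k_dot_kernel(x_ptr, y_ptr, out_ptr,
                       M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
                       BLOCK_K: tl.constexpr, SPLIT_K: tl.constexpr):
    # x: (M, K) row-major, y: (K, N) row-major; K is a multiple of SPLIT_K * BLOCK_K.
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    # 3D accumulator: one (M, N) slice per split-K group.
    acc3d = tl.zeros((SPLIT_K, M, N), dtype=tl.float32)
    for k0 in range(0, K, SPLIT_K * BLOCK_K):
        # Slice s of the batch reads the K sub-range [k0 + s*BLOCK_K, k0 + (s+1)*BLOCK_K).
        k_idx = k0 + tl.arange(0, SPLIT_K)[:, None] * BLOCK_K + tl.arange(0, BLOCK_K)[None, :]
        x3d = tl.load(x_ptr + offs_m[None, :, None] * K + k_idx[:, None, :])  # (SPLIT_K, M, BLOCK_K)
        y3d = tl.load(y_ptr + k_idx[:, :, None] * N + offs_n[None, None, :])  # (SPLIT_K, BLOCK_K, N)
        acc3d = tl.dot(x3d, y3d, acc3d)  # batched dot keeps SPLIT_K independent partial sums
    # The reduction over the batch axis happens once, outside the K loop.
    acc = tl.sum(acc3d, axis=0)
    tl.store(out_ptr + offs_m[:, None] * N + offs_n[None, :], acc)
```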
Why is this reduction in the loop in the first place? Is it because the compiler converts the dot into a dot3d in a previous pass?
@ThomasRaoux FMA dot requires all dot operand elements along the K dimension to be in one thread, which limits parallelism when K is large. I want to keep this optimization split in two parts for easier debugging and simplicity.
Split-K is in general the kind of optimization we want to let the user control rather than doing it in the compiler. Deciding when to do such a transformation would require a hard heuristic, and it is better to give the user control so that performance is more intuitive.
I've posted the optimization I mentioned. It is a transformation that generates the dot3d->reduction pattern: #4581
Yes, but here I want to achieve a specific pattern for a 2D tl.dot without forcing the user to reimplement existing code. My motivation for implementing it in the compiler is the following:
To be specific, let me show an example of the kind of dots I am targeting:
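For instance, a kernel along these lines (the kernel name and block size are illustrative assumptions; only the shapes come from the case described below):

```python
import triton
import triton.language as tl


@triton.jit
def small_dot_kernel(x_ptr, y_ptr, out_ptr,
                     M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
                     BLOCK_K: tl.constexpr):
    # Skinny dot, e.g. M=1, N=32, K=1024, which ends up on the FMA path
    # (assumes a Triton build that accepts M < 16 for tl.dot).
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    acc = tl.zeros((M, N), dtype=tl.float32)
    for k0 in range(0, K, BLOCK_K):
        offs_k = k0 + tl.arange(0, BLOCK_K)
        x = tl.load(x_ptr + offs_m[:, None] * K + offs_k[None, :])  # (M, BLOCK_K)
        y = tl.load(y_ptr + offs_k[:, None] * N + offs_n[None, :])  # (BLOCK_K, N)
        acc = tl.dot(x, y, acc)  # reduction over K stays inside the loop
    tl.store(out_ptr + offs_m[:, None] * N + offs_n[None, :], acc)
```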
In this case, M=1, N=32, K=1024. There are two main bottlenecks:
Here are the optimization ideas:
Let's focus on the split-K pattern, to make sure we are on the same page. Before the optimization we have only 32 independent entities (M*N) for 512 threads (64 threads * 8 warps).
After the optimization, the blocked2 layout can be constructed with threadsPerWarp=[16, 1, 4] and warpsPerCTA=[1, 1, 8], so Triton can fully utilize the available parallelism.
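To spell out the thread-utilization arithmetic (the 3D accumulator shape (SPLIT_K, M, N) = (16, 1, 32) is an assumption consistent with the numbers above):

```python
# Thread-utilization accounting for the M=1, N=32, K=1024 example.
warp_size, num_warps = 64, 8
threads_per_cta = warp_size * num_warps    # 512 threads available

# Before: a 2D (M, N) accumulator exposes only M * N independent elements.
independent_2d = 1 * 32                    # 32 -> most of the 512 threads are idle

# After: a 3D (SPLIT_K, M, N) = (16, 1, 32) accumulator with
# threadsPerWarp = [16, 1, 4] and warpsPerCTA = [1, 1, 8] covers
covered_3d = (16 * 1) * (1 * 1) * (4 * 8)  # 16 * 1 * 32 = 512 elements
assert covered_3d == threads_per_cta       # every thread owns one accumulator element
```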
I admit that my approach is not general at the moment, so I want to put it in the AMD backend for now.
I don't think this is a good reason to do it in the compiler, for the reasons I mentioned before. This topic comes up often, as it is tempting to have the compiler rewrite the algorithm under the hood to show performance, but this will hurt Triton's users in the long run: it gives them less control and will cause unpredictable performance.
What part depends on layout? Are you saying this optimization can or cannot be done based on how layouts are picked?
What does it need?
The transformation makes sense, but I'm not sure what the downside would be of having the user explicitly write the code that way.
Even if it is done in the AMD backend, we need to keep the design of Triton consistent, and I don't think we want to start adding transformations that rewrite the code in a different way based on heuristics that will be hard to get right.
Basically, yes. To decide whether this transformation is beneficial, the user should be aware of the hardware they are going to use. For example, if the dot can be executed with matrix/tensor cores, this reduction trick can hurt performance. The user should also consider the warp size, which is not available at the Triton language level. This could be resolved with autotuning, but that way we force the user to implement two dot versions in one kernel.
To make all of this work as intended (i.e., utilizing parallelism as much as possible and bypassing shared memory), Triton needs to emit compatible layouts all the way from the load to the dot, which is not controllable at the Triton language level.
I see, let me discuss this internally.
Closing this PR for now. Will reopen it if the base PRs (for example #4516) are merged and a decision on automatic optimization is made.
This PR is part of a PR series. The final goal is to improve the efficiency of small dot operations and bypass as many shared memory accesses as possible.
Rough list of PRs: