
Hoist reduction outside a loop #4559

Closed
wants to merge 1 commit

Conversation

@binarman (Contributor) commented Aug 22, 2024

This PR introduces an optimization that hoists the reduction of the dot accumulator out of the loop over the K dimension:

  %acc = <zero tensor>
  for k tiles:
    %acc3d_input = reshape %acc
    %acc3d_out = dot3d(%x, %y, %acc3d_input)
    %acc = reduction batch %acc3d_out

transforms to:

  %acc3d = <zero tensor>
  for k tiles:
    %acc3d = dot3d(%x, %y, %acc3d)
  %acc = reduction batch %acc3d
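
For intuition, here is a minimal NumPy sketch of the two forms. The shapes, names, and the modelling of "reshape %acc" as placing the 2D accumulator into one batch slice are my own assumptions; the point is only that both forms accumulate the same result:

  import numpy as np

  BATCH, M, N, KP, NUM_TILES = 16, 1, 32, 64, 4      # hypothetical tile sizes
  rng = np.random.default_rng(0)
  x_tiles = [rng.random((BATCH, M, KP)) for _ in range(NUM_TILES)]
  y_tiles = [rng.random((BATCH, KP, N)) for _ in range(NUM_TILES)]

  # Before: the batch reduction runs on every iteration of the K loop.
  acc = np.zeros((M, N))
  for x, y in zip(x_tiles, y_tiles):
      acc3d_input = np.zeros((BATCH, M, N))
      acc3d_input[0] = acc                            # model of "reshape %acc" (assumption)
      acc3d_out = np.matmul(x, y) + acc3d_input       # dot3d with accumulator
      acc = acc3d_out.sum(axis=0)                     # reduction over the batch dim

  # After: keep a 3D accumulator and hoist the reduction out of the loop.
  acc3d = np.zeros((BATCH, M, N))
  for x, y in zip(x_tiles, y_tiles):
      acc3d = np.matmul(x, y) + acc3d                 # dot3d only
  acc_hoisted = acc3d.sum(axis=0)                     # single reduction after the loop

  assert np.allclose(acc, acc_hoisted)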

This PR is part of a PR series. The final goal is to improve the efficiency of small dot operations and bypass as many shared memory accesses as possible.

Rough list of PRs:

@ThomasRaoux (Collaborator)

Why is this reduction in the loop in the first place? Is it because the compiler converts the dot into a dot3d in a previous pass?

@binarman (Author)

@ThomasRaoux
Yes, I will make a PR with this dot2d->dot3d conversion next.

The FMA dot requires all dot operand elements along the K dimension to be in one thread, which limits parallelism for large K.
The idea is to split the K dimension of dot3d into "batch" and K', compute the K' pieces in parallel, and then reduce over the batch dimension.
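
A minimal NumPy sketch of that split, with illustrative shapes chosen by me (K=1024 split into a batch of 16 by K'=64):

  import numpy as np

  M, N, K, BATCH = 1, 32, 1024, 16
  KP = K // BATCH                                     # K' = 64
  rng = np.random.default_rng(0)
  a = rng.random((M, K))
  b = rng.random((K, N))

  d = a @ b                                           # reference 2D dot

  # View K as (batch, K'), run BATCH small dots in parallel, reduce over batch.
  a3d = a.reshape(M, BATCH, KP).transpose(1, 0, 2)    # (BATCH, M, K')
  b3d = b.reshape(BATCH, KP, N)                       # (BATCH, K', N)
  d3d = np.matmul(a3d, b3d)                           # (BATCH, M, N) partial dots
  assert np.allclose(d, d3d.sum(axis=0))              # batch reduction == full dot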

I want to keep this optimization in two parts for easier debugging and simplicity.

@ThomasRaoux (Collaborator)

> Yes, I will make a PR with this dot2d->dot3d conversion next.
>
> The FMA dot requires all dot operand elements along the K dimension to be in one thread, which limits parallelism for large K. The idea is to split the K dimension of dot3d into "batch" and K', compute the K' pieces in parallel, and then reduce over the batch dimension.
>
> I want to keep this optimization in two parts for easier debugging and simplicity.

Split-K is, in general, the kind of optimization we want to let the user control rather than doing in the compiler. Deciding when to do such a transformation would be a hard heuristic, and it is better to give users control so that performance is more intuitive for them.

@binarman mentioned this pull request Aug 26, 2024
@binarman (Author) commented Aug 26, 2024

@ThomasRaoux

I've posted the optimization I mentioned. It is a transformation that generates the dot3d->reduction pattern: #4581

> Split-K is, in general, the kind of optimization we want to let the user control rather than doing in the compiler.

Yes, but here I want to achieve a specific pattern for 2D tl.dot without forcing users to reimplement existing code.

My motivation for implementing it in the compiler is the following:

  • this transparency was explicitly requested from me
  • understanding why this split-K is needed requires layout knowledge, which significantly raises the amount of Triton internals users are expected to understand
  • a user-controlled approach needs new optimizations on the Triton side anyway

To be specific, let me show an example of the kind of dots I am targeting:

  1. dot operands are loaded from global memory
  2. the dot looks like this: (1x1024) x (1024x32) = (1x32), num_warps=8

In this case, M=1, N=32, K=1024.

There are two main bottlenecks:

  1. the large B tensor requires a lot of operations to convert layouts through shared memory
  2. limited parallelism available in the dot (the FMA dot requires all elements along the K dim to be held by one thread), so we have only M=1 and N=32 to fill 64 threads (for MI hardware) and 8 warps

Here are the optimization ideas:

  1. bypass shared memory for the second operand by loading data directly in a compatible FMA dot operand layout.
  2. split K and distribute the computation across the threads of one warp. If we transform the dot inside the compiler, we can easily adjust layouts so the reduction happens entirely within one warp.

Let's focus on the split-K pattern, to make sure we are on the same page:

Before optimization (we have only 32 independent output elements (M*N) for 512 threads (64 threads * 8 warps)):

%b = tt.load %bPtr : tensor<1024x32x!tt.ptr<i8>>
%bOp = triton_gpu.convert_layout %b : tensor<1024x32xi8, #blocked> -> tensor<1024x32xi8, #dot_operand_b>
%d = tt.dot %aOp, %bOp, %cOp : tensor<1x1024xi8, #dot_operand_a> * tensor<1024x32xi8, #dot_operand_b> -> tensor<1x32xi32, #blocked>

After optimization (a blocked2 layout can be constructed with threadsPerWarp=[16, 1, 4] and warpsPerCTA=[1, 1, 8]; this way Triton can fully utilize the available parallelism):

%b = tt.load %bPtr : tensor<1024x32x!tt.ptr<i8>>
%b_batched = tt.reshape %b : tensor<1024x32xi8> -> tensor<16x64x32xi8>
%bOp = triton_gpu.convert_layout %b_batched : tensor<16x64x32xi8, #blocked> -> tensor<16x64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>>
%d_3d = tt.dot %aOp, %bOp, %cOp : tensor<16x1x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<16x64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<16x1x32xi32, #blocked2>
%red = "tt.reduce"(%d_3d) <{axis = 0 : i32}> ({
  ^bb0(%arg1: i32, %arg2: i32):
    %11 = arith.addi %arg1, %arg2 : i32
    tt.reduce.return %11 : i32
  }) : (tensor<16x1x32xi32, #blocked2>) -> tensor<1x32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%d = triton_gpu.convert_layout %red : tensor<1x32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x32xi32, #blocked>
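
For reference, the parallelism arithmetic behind these layout choices (a rough sketch of the numbers above, assuming the 64-thread warp of MI hardware and num_warps=8 as stated):

  warp_size, num_warps = 64, 8                        # MI hardware assumption from the text
  threads = warp_size * num_warps                     # 512 threads in total
  before = 1 * 32                                     # M*N independent output elements (2D dot)
  after = 16 * 1 * 32                                 # batch*M*N output elements (3D dot)
  assert before < threads and after == threads        # 32 vs. 512: one element per thread

  # threadsPerWarp=[16, 1, 4] with warpsPerCTA=[1, 1, 8] tiles the 16x1x32 output exactly:
  assert 16 * 1 * 4 == warp_size
  assert [16 * 1, 1 * 1, 4 * 8] == [16, 1, 32]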

I admit that my approach is not general at the moment, so I want to put it in the AMD backend for now.

@ThomasRaoux (Collaborator)

> I've posted the optimization I mentioned. It is a transformation that generates the dot3d->reduction pattern: #4581
>
> > Split-K is, in general, the kind of optimization we want to let the user control rather than doing in the compiler.
>
> Yes, but here I want to achieve a specific pattern for 2D tl.dot without forcing users to reimplement existing code.

I think this is not a good reason to do it in the compiler, for the reasons I mentioned before. This topic comes up often: it is tempting to have the compiler rewrite the algorithm under the hood to show performance, but this hurts Triton's users in the long run, since it gives them less control and causes unpredictable performance.

> My motivation for implementing it in the compiler is the following:
>
>   • this transparency was explicitly requested from me
>   • understanding why this split-K is needed requires layout knowledge, which significantly raises the amount of Triton internals users are expected to understand

What part depends on layout? Are you saying this optimization can or cannot be done based on how layouts are picked?

>   • a user-controlled approach needs new optimizations on the Triton side anyway

What does it need?

> [full quote of the example above: the targeted dot shapes, the two bottlenecks, the optimization ideas, and the before/after IR]

The transformation makes sense, but I'm not sure what the downside would be of having the user explicitly write the code that way?

> I admit that my approach is not general at the moment, so I want to put it in the AMD backend for now.

Even if it is done in the AMD backend, we need to keep the design of Triton consistent, and I don't think we want to start adding transformations that rewrite the code in a different way based on heuristics that will be hard to get right.

@binarman (Author)

> What part depends on layout? Are you saying this optimization can or cannot be done based on how layouts are picked?

Basically, yes. To decide whether this transformation is beneficial, the user has to be aware of the hardware they are going to use. For example, if the dot can be executed on matrix/tensor cores, this reduction trick can hurt performance. The user should also consider warp size, which is not available at the Triton language level. This could be resolved with autotuning, but then we force the user to implement two dot versions in one kernel.

> The transformation makes sense, but I'm not sure what the downside would be of having the user explicitly write the code that way?

To make all of this work as intended (i.e. utilizing parallelism as much as possible and bypassing shared memory), Triton needs to emit compatible layouts all the way from the load to the dot, which is not controlled at the Triton language level.
To adjust these layouts we need new transformations, or changes to RemoveLayoutConversions (which is already complex, in my opinion).

> Even if it is done in the AMD backend, we need to keep the design of Triton consistent, and I don't think we want to start adding transformations that rewrite the code in a different way based on heuristics that will be hard to get right.

I see, let me discuss this internally.

@binarman (Author)

Closing this PR for now. I will reopen it if the base PRs (for example #4516) are merged and a decision on the automatic optimization is made.

@binarman closed this Nov 18, 2024