
[Transform][Tiling] Add deep tile support for matmul #90

Merged: 21 commits into main from zhicong/deep_tile_matmul on Aug 9, 2024

Conversation

@zhczhong (Member) commented on May 20, 2024

Tracking #53

TODO:

  • Generate the nested outer loops (see the structural sketch after this list)
  • Partial reduction support
    • Enhance the PartialReductionOpInterface to let the user control where the new parallel dims are inserted
    • Erase the redundant linalg.fill op in partial reduction
  • Merge all parallel iterators into a single scf.forall until nested parallelism is ready
  • Fuse the linalg.fill op into the innermost loop body
  • Replace all generic ops with linalg named ops
  • Support 4Dx4D/5D->4D, 2Dx2D->2D, and 2Dx4D/5D->2D
  • Data type support (f32, bf16)
  • Fuse the f32->bf16 cast into the last loop along the K axis
  • Support batch matmul
  • Balance211 support
  • Tune a general matmul config based on a cost model
  • Fuse the linalg.copy into the innermost loop
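
To make the target structure concrete, here is a minimal sketch of the deep-tile shape: an outer parallel scf.forall over M/N output tiles plus a sequential scf.for reduction loop over K that carries the accumulator tile. This is illustrative only, not this pass's actual output; the function name, tile sizes, and the use of plain f32 linalg.matmul are assumptions made for brevity.

// Illustrative sketch only: outer parallel M/N tiling + sequential K reduction loop.
func.func @deep_tile_sketch(%A: tensor<128x128xf32>, %B: tensor<128x128xf32>,
                            %C: tensor<128x128xf32>) -> tensor<128x128xf32> {
  %c0 = arith.constant 0 : index
  %c32 = arith.constant 32 : index
  %c128 = arith.constant 128 : index
  // One parallel iteration per 32x32 output tile.
  %res = scf.forall (%i, %j) in (4, 4) shared_outs(%out = %C) -> (tensor<128x128xf32>) {
    %im = affine.apply affine_map<(d0) -> (d0 * 32)>(%i)
    %jn = affine.apply affine_map<(d0) -> (d0 * 32)>(%j)
    %tile = tensor.extract_slice %out[%im, %jn] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32>
    // Sequential reduction over K; the accumulator tile is loop-carried.
    %acc = scf.for %k = %c0 to %c128 step %c32 iter_args(%a = %tile) -> (tensor<32x32xf32>) {
      %lhs = tensor.extract_slice %A[%im, %k] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32>
      %rhs = tensor.extract_slice %B[%k, %jn] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32>
      %m = linalg.matmul ins(%lhs, %rhs : tensor<32x32xf32>, tensor<32x32xf32>) outs(%a : tensor<32x32xf32>) -> tensor<32x32xf32>
      scf.yield %m : tensor<32x32xf32>
    }
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %acc into %out[%im, %jn] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<128x128xf32>
    }
  }
  return %res : tensor<128x128xf32>
}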

zhczhong added the WIP (work in progress) label on May 20, 2024
zhczhong force-pushed the zhicong/deep_tile_matmul branch 3 times, most recently from 7c8cfbb to 927322a, on May 23, 2024 06:11
zhczhong force-pushed the zhicong/deep_tile_matmul branch 6 times, most recently from ea02416 to f261c3c, on June 3, 2024 03:47
zhczhong force-pushed the zhicong/deep_tile_matmul branch 5 times, most recently from 5ed4fc1 to 22d86d4, on June 5, 2024 03:21
@zhczhong (Member, Author) commented on Jun 5, 2024

Support using linalgx.batch_reduce_matmul_vnni (bf16 x bf16 -> f32) and fuse the cast (f32 -> bf16) into the last loop along the K axis:

func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>, %arg1: tensor<128x128x16x32x2xbf16>) -> tensor<128x128x32x32xbf16> {
    %cst_0 = arith.constant 0.000000e+00 : bf16
    %0 = tensor.empty() : tensor<128x128x32x32xbf16>
    %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16>
    %2 = linalgx.mm4d_vnni ins(%arg0, %arg1 : tensor<128x128x32x32xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<128x128x32x32xbf16>)  -> tensor<128x128x32x32xbf16>
    return %2 : tensor<128x128x32x32xbf16>
}

will be transformed into

#map = affine_map<(d0) -> (d0 * 64)>
#map1 = affine_map<(d0)[s0, s1] -> (d0 * 64 + s0 + s1)>
module {
  func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>, %arg1: tensor<128x128x16x32x2xbf16>) -> tensor<128x128x32x32xbf16> {
    %c1 = arith.constant 1 : index
    %c128 = arith.constant 128 : index
    %c2 = arith.constant 2 : index
    %c64 = arith.constant 64 : index
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<128x128x32x32xbf16>
    %1 = scf.forall (%arg2, %arg3) in (2, 2) shared_outs(%arg4 = %0) -> (tensor<128x128x32x32xbf16>) {
      %2 = affine.apply #map(%arg2)
      %3 = affine.apply #map(%arg3)
      %extracted_slice = tensor.extract_slice %arg4[%2, %3, 0, 0] [64, 64, 32, 32] [1, 1, 1, 1] : tensor<128x128x32x32xbf16> to tensor<64x64x32x32xbf16>
      %4 = scf.for %arg5 = %c0 to %c64 step %c2 iter_args(%arg6 = %extracted_slice) -> (tensor<64x64x32x32xbf16>) {
        %extracted_slice_0 = tensor.extract_slice %arg6[%arg5, 0, 0, 0] [2, 64, 32, 32] [1, 1, 1, 1] : tensor<64x64x32x32xbf16> to tensor<2x64x32x32xbf16>
        %7 = scf.for %arg7 = %c0 to %c64 step %c2 iter_args(%arg8 = %extracted_slice_0) -> (tensor<2x64x32x32xbf16>) {
          %extracted_slice_1 = tensor.extract_slice %arg8[0, %arg7, 0, 0] [2, 2, 32, 32] [1, 1, 1, 1] : tensor<2x64x32x32xbf16> to tensor<2x2x32x32xbf16>
          %8 = tensor.empty() : tensor<2x2x32x32xf32>
          %9 = scf.for %arg9 = %c0 to %c128 step %c2 iter_args(%arg10 = %8) -> (tensor<2x2x32x32xf32>) {
            %11 = scf.for %arg11 = %c0 to %c2 step %c1 iter_args(%arg12 = %arg10) -> (tensor<2x2x32x32xf32>) {
              %extracted_slice_3 = tensor.extract_slice %arg12[%arg11, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : tensor<2x2x32x32xf32> to tensor<1x2x32x32xf32>
              %12 = scf.for %arg13 = %c0 to %c2 step %c1 iter_args(%arg14 = %extracted_slice_3) -> (tensor<1x2x32x32xf32>) {
                %13 = affine.apply #map1(%arg2)[%arg11, %arg5]
                %extracted_slice_5 = tensor.extract_slice %arg0[%13, %arg9, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : tensor<128x128x32x32xbf16> to tensor<2x32x32xbf16>
                %14 = affine.apply #map1(%arg3)[%arg13, %arg7]
                %extracted_slice_6 = tensor.extract_slice %arg1[%14, %arg9, 0, 0, 0] [1, 2, 16, 32, 2] [1, 1, 1, 1, 1] : tensor<128x128x16x32x2xbf16> to tensor<2x16x32x2xbf16>
                %extracted_slice_7 = tensor.extract_slice %arg14[0, %arg13, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<1x2x32x32xf32> to tensor<32x32xf32>
                %15 = arith.cmpi eq, %arg9, %c0 : index
                %16 = scf.if %15 -> (tensor<32x32xf32>) {
                  %17 = linalg.fill ins(%cst : f32) outs(%extracted_slice_7 : tensor<32x32xf32>) -> tensor<32x32xf32>
                  %18 = linalgx.batch_reduce_matmul_vnni ins(%extracted_slice_5, %extracted_slice_6 : tensor<2x32x32xbf16>, tensor<2x16x32x2xbf16>) outs(%17 : tensor<32x32xf32>) -> tensor<32x32xf32>
                  scf.yield %18 : tensor<32x32xf32>
                } else {
                  %17 = linalgx.batch_reduce_matmul_vnni ins(%extracted_slice_5, %extracted_slice_6 : tensor<2x32x32xbf16>, tensor<2x16x32x2xbf16>) outs(%extracted_slice_7 : tensor<32x32xf32>) -> tensor<32x32xf32>
                  scf.yield %17 : tensor<32x32xf32>
                }
                %inserted_slice_8 = tensor.insert_slice %16 into %arg14[0, %arg13, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<1x2x32x32xf32>
                scf.yield %inserted_slice_8 : tensor<1x2x32x32xf32>
              }
              %inserted_slice_4 = tensor.insert_slice %12 into %arg12[%arg11, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : tensor<1x2x32x32xf32> into tensor<2x2x32x32xf32>
              scf.yield %inserted_slice_4 : tensor<2x2x32x32xf32>
            }
            scf.yield %11 : tensor<2x2x32x32xf32>
          }
          %10 = linalg.copy ins(%9 : tensor<2x2x32x32xf32>) outs(%extracted_slice_1 : tensor<2x2x32x32xbf16>) -> tensor<2x2x32x32xbf16>
          %inserted_slice_2 = tensor.insert_slice %10 into %arg8[0, %arg7, 0, 0] [2, 2, 32, 32] [1, 1, 1, 1] : tensor<2x2x32x32xbf16> into tensor<2x64x32x32xbf16>
          scf.yield %inserted_slice_2 : tensor<2x64x32x32xbf16>
        }
        %inserted_slice = tensor.insert_slice %7 into %arg6[%arg5, 0, 0, 0] [2, 64, 32, 32] [1, 1, 1, 1] : tensor<2x64x32x32xbf16> into tensor<64x64x32x32xbf16>
        scf.yield %inserted_slice : tensor<64x64x32x32xbf16>
      }
      %5 = affine.apply #map(%arg2)
      %6 = affine.apply #map(%arg3)
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %4 into %arg4[%5, %6, 0, 0] [64, 64, 32, 32] [1, 1, 1, 1] : tensor<64x64x32x32xbf16> into tensor<128x128x32x32xbf16>
      }
    }
    return %1 : tensor<128x128x32x32xbf16>
  }
}
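
Distilled from the output above: the partial-reduction handling folds linalg.fill into the K loop as a first-iteration special case (the scf.if on %arg9 == %c0), so no standalone fill over the whole output remains. Below is a minimal standalone sketch of just that pattern; it is illustrative only, with f32 and plain linalg.matmul standing in for the bf16 VNNI ops, and made-up names and sizes.

// Illustrative sketch only: linalg.fill fused into the K loop as a first-iteration case.
func.func @fill_fusion_sketch(%lhs: tensor<32x128xf32>, %rhs: tensor<128x32xf32>,
                              %acc0: tensor<32x32xf32>) -> tensor<32x32xf32> {
  %c0 = arith.constant 0 : index
  %c32 = arith.constant 32 : index
  %c128 = arith.constant 128 : index
  %zero = arith.constant 0.0 : f32
  %res = scf.for %k = %c0 to %c128 step %c32 iter_args(%acc = %acc0) -> (tensor<32x32xf32>) {
    %a = tensor.extract_slice %lhs[0, %k] [32, 32] [1, 1] : tensor<32x128xf32> to tensor<32x32xf32>
    %b = tensor.extract_slice %rhs[%k, 0] [32, 32] [1, 1] : tensor<128x32xf32> to tensor<32x32xf32>
    %first = arith.cmpi eq, %k, %c0 : index
    %r = scf.if %first -> (tensor<32x32xf32>) {
      // First K iteration: zero the accumulator tile, then multiply-accumulate.
      %z = linalg.fill ins(%zero : f32) outs(%acc : tensor<32x32xf32>) -> tensor<32x32xf32>
      %m = linalg.matmul ins(%a, %b : tensor<32x32xf32>, tensor<32x32xf32>) outs(%z : tensor<32x32xf32>) -> tensor<32x32xf32>
      scf.yield %m : tensor<32x32xf32>
    } else {
      // Later iterations: accumulate into the loop-carried partial result.
      %m = linalg.matmul ins(%a, %b : tensor<32x32xf32>, tensor<32x32xf32>) outs(%acc : tensor<32x32xf32>) -> tensor<32x32xf32>
      scf.yield %m : tensor<32x32xf32>
    }
    scf.yield %r : tensor<32x32xf32>
  }
  return %res : tensor<32x32xf32>
}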

@zhczhong (Member, Author)

Update: fuse the cast (f32 -> bf16) into the innermost loop:

func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>, %arg1: tensor<128x128x16x32x2xbf16>) -> tensor<128x128x32x32xbf16> {
    %cst_0 = arith.constant 0.000000e+00 : bf16
    %0 = tensor.empty() : tensor<128x128x32x32xbf16>
    %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<128x128x32x32xbf16>) -> tensor<128x128x32x32xbf16>
    %2 = linalgx.mm4d_vnni ins(%arg0, %arg1 : tensor<128x128x32x32xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<128x128x32x32xbf16>)  -> tensor<128x128x32x32xbf16>
    return %2 : tensor<128x128x32x32xbf16>
}

will be transformed into

#map = affine_map<(d0) -> (d0 * 64)>
#map1 = affine_map<(d0)[s0, s1] -> (d0 * 64 + s0 + s1)>
module {
  func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>, %arg1: tensor<128x128x16x32x2xbf16>) -> tensor<128x128x32x32xbf16> {
    %c1 = arith.constant 1 : index
    %c128 = arith.constant 128 : index
    %c2 = arith.constant 2 : index
    %c64 = arith.constant 64 : index
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<128x128x32x32xbf16>
    %1 = scf.forall (%arg2, %arg3) in (2, 2) shared_outs(%arg4 = %0) -> (tensor<128x128x32x32xbf16>) {
      %2 = affine.apply #map(%arg2)
      %3 = affine.apply #map(%arg3)
      %extracted_slice = tensor.extract_slice %arg4[%2, %3, 0, 0] [64, 64, 32, 32] [1, 1, 1, 1] : tensor<128x128x32x32xbf16> to tensor<64x64x32x32xbf16>
      %4 = scf.for %arg5 = %c0 to %c64 step %c2 iter_args(%arg6 = %extracted_slice) -> (tensor<64x64x32x32xbf16>) {
        %extracted_slice_0 = tensor.extract_slice %arg6[%arg5, 0, 0, 0] [2, 64, 32, 32] [1, 1, 1, 1] : tensor<64x64x32x32xbf16> to tensor<2x64x32x32xbf16>
        %7 = scf.for %arg7 = %c0 to %c64 step %c2 iter_args(%arg8 = %extracted_slice_0) -> (tensor<2x64x32x32xbf16>) {
          %extracted_slice_1 = tensor.extract_slice %arg8[0, %arg7, 0, 0] [2, 2, 32, 32] [1, 1, 1, 1] : tensor<2x64x32x32xbf16> to tensor<2x2x32x32xbf16>
          %8 = tensor.empty() : tensor<2x2x32x32xf32>
          %9:2 = scf.for %arg9 = %c0 to %c128 step %c2 iter_args(%arg10 = %8, %arg11 = %extracted_slice_1) -> (tensor<2x2x32x32xf32>, tensor<2x2x32x32xbf16>) {
            %10:2 = scf.for %arg12 = %c0 to %c2 step %c1 iter_args(%arg13 = %arg10, %arg14 = %arg11) -> (tensor<2x2x32x32xf32>, tensor<2x2x32x32xbf16>) {
              %extracted_slice_3 = tensor.extract_slice %arg13[%arg12, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : tensor<2x2x32x32xf32> to tensor<1x2x32x32xf32>
              %extracted_slice_4 = tensor.extract_slice %arg14[%arg12, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : tensor<2x2x32x32xbf16> to tensor<1x2x32x32xbf16>
              %11:2 = scf.for %arg15 = %c0 to %c2 step %c1 iter_args(%arg16 = %extracted_slice_3, %arg17 = %extracted_slice_4) -> (tensor<1x2x32x32xf32>, tensor<1x2x32x32xbf16>) {
                %12 = affine.apply #map1(%arg2)[%arg12, %arg5]
                %extracted_slice_7 = tensor.extract_slice %arg0[%12, %arg9, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : tensor<128x128x32x32xbf16> to tensor<2x32x32xbf16>
                %13 = affine.apply #map1(%arg3)[%arg15, %arg7]
                %extracted_slice_8 = tensor.extract_slice %arg1[%13, %arg9, 0, 0, 0] [1, 2, 16, 32, 2] [1, 1, 1, 1, 1] : tensor<128x128x16x32x2xbf16> to tensor<2x16x32x2xbf16>
                %extracted_slice_9 = tensor.extract_slice %arg16[0, %arg15, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<1x2x32x32xf32> to tensor<32x32xf32>
                %extracted_slice_10 = tensor.extract_slice %arg17[0, %arg15, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<1x2x32x32xbf16> to tensor<32x32xbf16>
                %14 = arith.cmpi eq, %arg9, %c0 : index
                %15 = scf.if %14 -> (tensor<32x32xf32>) {
                  %18 = linalg.fill ins(%cst : f32) outs(%extracted_slice_9 : tensor<32x32xf32>) -> tensor<32x32xf32>
                  %19 = linalgx.batch_reduce_matmul_vnni ins(%extracted_slice_7, %extracted_slice_8 : tensor<2x32x32xbf16>, tensor<2x16x32x2xbf16>) outs(%18 : tensor<32x32xf32>) -> tensor<32x32xf32>
                  scf.yield %19 : tensor<32x32xf32>
                } else {
                  %18 = linalgx.batch_reduce_matmul_vnni ins(%extracted_slice_7, %extracted_slice_8 : tensor<2x32x32xbf16>, tensor<2x16x32x2xbf16>) outs(%extracted_slice_9 : tensor<32x32xf32>) -> tensor<32x32xf32>
                  scf.yield %18 : tensor<32x32xf32>
                }
                %16 = arith.cmpi eq, %arg9, %c0 : index
                %17 = scf.if %16 -> (tensor<32x32xbf16>) {
                  %18 = linalg.copy ins(%15 : tensor<32x32xf32>) outs(%extracted_slice_10 : tensor<32x32xbf16>) -> tensor<32x32xbf16>
                  scf.yield %18 : tensor<32x32xbf16>
                } else {
                  scf.yield %extracted_slice_10 : tensor<32x32xbf16>
                }
                %inserted_slice_11 = tensor.insert_slice %15 into %arg16[0, %arg15, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<1x2x32x32xf32>
                %inserted_slice_12 = tensor.insert_slice %17 into %arg17[0, %arg15, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xbf16> into tensor<1x2x32x32xbf16>
                scf.yield %inserted_slice_11, %inserted_slice_12 : tensor<1x2x32x32xf32>, tensor<1x2x32x32xbf16>
              }
              %inserted_slice_5 = tensor.insert_slice %11#0 into %arg13[%arg12, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : tensor<1x2x32x32xf32> into tensor<2x2x32x32xf32>
              %inserted_slice_6 = tensor.insert_slice %11#1 into %arg14[%arg12, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : tensor<1x2x32x32xbf16> into tensor<2x2x32x32xbf16>
              scf.yield %inserted_slice_5, %inserted_slice_6 : tensor<2x2x32x32xf32>, tensor<2x2x32x32xbf16>
            }
            scf.yield %10#0, %10#1 : tensor<2x2x32x32xf32>, tensor<2x2x32x32xbf16>
          }
          %inserted_slice_2 = tensor.insert_slice %9#1 into %arg8[0, %arg7, 0, 0] [2, 2, 32, 32] [1, 1, 1, 1] : tensor<2x2x32x32xbf16> into tensor<2x64x32x32xbf16>
          scf.yield %inserted_slice_2 : tensor<2x64x32x32xbf16>
        }
        %inserted_slice = tensor.insert_slice %7 into %arg6[%arg5, 0, 0, 0] [2, 64, 32, 32] [1, 1, 1, 1] : tensor<2x64x32x32xbf16> into tensor<64x64x32x32xbf16>
        scf.yield %inserted_slice : tensor<64x64x32x32xbf16>
      }
      %5 = affine.apply #map(%arg2)
      %6 = affine.apply #map(%arg3)
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %4 into %arg4[%5, %6, 0, 0] [64, 64, 32, 32] [1, 1, 1, 1] : tensor<64x64x32x32xbf16> into tensor<128x128x32x32xbf16>
      }
    }
    return %1 : tensor<128x128x32x32xbf16>
  }
}
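
The essential change relative to the previous version, as a minimal sketch (illustrative only; linalg.matmul stands in for linalgx.batch_reduce_matmul_vnni, and the names are made up): the K loop now carries two iter_args, the f32 accumulator and the bf16 destination tile, so the f32 -> bf16 copy can be issued inside the innermost body rather than once after the K loop finishes.

// Illustrative sketch only: dual loop-carried values fuse the downcast into the K loop.
func.func @fused_cast_sketch(%lhs: tensor<32x128xbf16>, %rhs: tensor<128x32xbf16>) -> tensor<32x32xbf16> {
  %c0 = arith.constant 0 : index
  %c32 = arith.constant 32 : index
  %c128 = arith.constant 128 : index
  %zero = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<32x32xf32>
  %acc0 = linalg.fill ins(%zero : f32) outs(%empty : tensor<32x32xf32>) -> tensor<32x32xf32>
  %dst0 = tensor.empty() : tensor<32x32xbf16>
  %r:2 = scf.for %k = %c0 to %c128 step %c32 iter_args(%acc = %acc0, %dst = %dst0) -> (tensor<32x32xf32>, tensor<32x32xbf16>) {
    %a = tensor.extract_slice %lhs[0, %k] [32, 32] [1, 1] : tensor<32x128xbf16> to tensor<32x32xbf16>
    %b = tensor.extract_slice %rhs[%k, 0] [32, 32] [1, 1] : tensor<128x32xbf16> to tensor<32x32xbf16>
    // Accumulate in f32 (the named op upcasts the bf16 operands).
    %part = linalg.matmul ins(%a, %b : tensor<32x32xbf16>, tensor<32x32xbf16>) outs(%acc : tensor<32x32xf32>) -> tensor<32x32xf32>
    // Downcast inside the loop. For simplicity this copies on every iteration
    // (the pass guards it with an scf.if); the last iteration's copy is the result.
    %cast = linalg.copy ins(%part : tensor<32x32xf32>) outs(%dst : tensor<32x32xbf16>) -> tensor<32x32xbf16>
    scf.yield %part, %cast : tensor<32x32xf32>, tensor<32x32xbf16>
  }
  return %r#1 : tensor<32x32xbf16>
}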

zhczhong force-pushed the zhicong/deep_tile_matmul branch 4 times, most recently from d69856f to 823be69, on July 2, 2024 02:53
zhczhong linked an issue on Jul 10, 2024 that may be closed by this pull request
zhczhong force-pushed the zhicong/deep_tile_matmul branch 3 times, most recently from 304dcde to 9dce4b3, on August 7, 2024 03:17
ZhennanQin merged commit 8948c6b into main on Aug 9, 2024 (4 checks passed)
zhczhong deleted the zhicong/deep_tile_matmul branch on August 29, 2024
Merging this pull request may close the issue: nested matmul implementation